%% 

addpath FOCUSS_utils/
addpath ksvdbox13/
addpath ksvdbox13/ompbox10/


% load data (each .mat file is expected to define a variable of the same name)
load data_slice30_HCP28         % training data from subject HCP28
load msk_slice30_HCP28          % brain mask for HCP28

load data_slice40_HCP25         % test data from subject HCP25
load msk_slice40_HCP25          % brain mask for HCP25


% q-space locations: one q-vector per row; columns 1..3 are used as the
% three grid coordinates further down
q = textread('qvecs_515.txt');  
q = 5*round(q*100)/100;         % round to 2 decimals, then scale by 5 (NOTE(review): presumably maps q onto a radius-5 grid -- confirm)


%% select pdfs for training from 1 slice

side_len = 12;          % zero padding in q-space
DC_offset = 7;          % grid index of the q-space origin along each axis


% Map every (rounded) q-vector onto a side_len^3 Cartesian grid and record
% the occupied grid points in cart_mask.
cart_mask = zeros(side_len,side_len,side_len);
cart_grid = round(q) + DC_offset;           % DC at (7,7,7)
ind = sub2ind(size(cart_mask),cart_grid(:,1),cart_grid(:,2),cart_grid(:,3)); 

% Two q-vectors that round to the same grid point would silently overwrite
% each other when the samples are scattered with q_space(ind) = data_1d.
n_dup = numel(ind) - numel(unique(ind));
if n_dup > 0
    warning('%d q-space samples map onto an already occupied grid point and will be overwritten', n_dup);
end

cart_mask(ind) = 1;
DFT_all = FT_v2(cart_mask);                 % DFT operator corresponding to all sampled q-space points



% Build the training set: for every brain voxel of the training slice,
% scatter its q-space samples onto the Cartesian grid, apply the adjoint
% DFT operator and store the resulting PDF as one column of PDFs.
data_slice = data_slice30_HCP28;
msk_slice = msk_slice30_HCP28;
       
figure(1), imagesc( data_slice(:,:,1) ), axis image, title('training data'), colormap hot, colorbar, 

% Preallocate one column (side_len^3 entries) per masked voxel instead of
% growing PDFs inside the loop, which reallocates and copies on every column.
num_train = nnz(msk_slice > 0);
PDFs = zeros(side_len^3, num_train);
count = 1;

for ay = 1:size(data_slice,1)
    for cey = 1:size(data_slice,2)
        if msk_slice(ay,cey) > 0 
            data_1d = squeeze(data_slice(ay,cey,:));
            q_space = zeros(side_len,side_len,side_len);
            q_space(ind) = data_1d;  

            pdf = DFT_all'*q_space;
            PDFs(:,count) = pdf(:);     % a complex pdf turns PDFs complex, as before
            count = count + 1;
        end
    end
end

tile_pdf(reshape(mean(PDFs,2),[side_len,side_len,side_len]),3,4,2,'average pdf from training slice',[0,.5]), colormap jet



%% load sampling mask

load('mask_cs_R3')      % expected to define mask_cs, the under-sampling pattern (TODO confirm)

% DFT operator restricted to the q-space points kept by mask_cs
DFT_cs = FT_v2(mask_cs);

% Show the kept points. R is the reciprocal of the fraction of Cartesian
% sample locations that mask_cs retains (assuming mask_cs is 0/1).
tile_pdf(mask_cs .* cart_mask, 3, 4, 2, ...
    ['reduced mask, R = ', num2str(1/mean(mask_cs(cart_mask==1)))], [0,1])



%% train dictionary using k-svd toolbox

Dictionary_Size = 258;
sigma = 1e-2;                                   % noise std

% Start from an empty struct. Assigning fields onto an existing params
% struct would keep stale fields from an earlier run in the same workspace
% (e.g. a params.Tdata that has since been commented out), which ksvd may
% then use to select a different coding mode.
params = struct();
params.data = real(PDFs);                       % one training PDF per column
%params.Tdata = 20;                             % sparsity level of each example in the basis
params.Edata = sigma * sqrt(size(PDFs,1));      % suggested value = sigma * sqrt(signal_dimension)
                                                % this should be set so that mean atomnum is in [0,5]
params.dictsize = Dictionary_Size;       
params.iternum = 30;
params.memusage = 'high';

tic
    [Dksvd,g,err] = ksvd(params,1);
toc


figure(3), imagesc(Dksvd), axis square, colormap hot


%% run L1-FOCUSS and Dictionary-FOCUSS on test slice


focuss_iter = 5;       % number of outer loops
cg_iter = 30;          % number of inner loops

lambda = 0e-4;         % regularization parameter (set to zero for no regularization)
p = 1;                 % Lp - norm to use in FOCUSS


test_slice = data_slice40_HCP25;
mask_test_slice = msk_slice40_HCP25;

% Binarize the mask once. The same logical array drives the voxel loop, the
% voxel count and the final averaging. Indexing with a 0/1 double mask would
% fail, and a non-binary mask would be miscounted. "> 0" matches the test
% used when extracting the training PDFs.
mask_test = mask_test_slice > 0;
n_rows = size(test_slice,1);
n_cols = size(test_slice,2);

% Sizes follow the q-space grid and the test slice instead of being
% hard-coded, so results land at the right voxel positions for any slice size.
pdf_result_size = [side_len, side_len, side_len, n_rows, n_cols];

PDF_DICT = zeros(pdf_result_size);
PDF_L1 = zeros(pdf_result_size);

figure(1), imagesc( test_slice(:,:,1) ), axis image, colormap hot; colorbar, drawnow


total_voxels = nnz(mask_test);

dict_rmse = zeros(n_rows, n_cols);     % per-voxel error (%) of the dictionary recon
L1_rmse = zeros(n_rows, n_cols);       % per-voxel error (%) of the L1 recon

count = 1;

tic
for ay = 1:n_rows
    for cey = 1:n_cols       

        if mask_test(ay,cey)

            % ground-truth pdf from the fully sampled q-space of this voxel,
            % then its under-sampled measurement and zero-filled estimate
            data_1d = squeeze(test_slice(ay,cey,:));
            q_space = zeros(side_len,side_len,side_len);
            q_space(ind) = data_1d;  

            pdf_true = DFT_all' * q_space; 
            q_space_reduced = DFT_cs * pdf_true;
            pdf_reduced = DFT_cs' * q_space_reduced;


            % L1 recon
            f_res = focuss_DSI( pdf_reduced, q_space_reduced, DFT_cs, lambda, cg_iter, focuss_iter, p, pdf_true, 1 );
            PDF_L1(:,:,:,ay,cey) = f_res;
            L1_rmse(ay,cey) = 100 * norm(f_res(:)-pdf_true(:)) / norm(pdf_true(:));
        
            % dictionary recon: initial coefficients from projecting the
            % zero-filled pdf onto the dictionary atoms
            f_dict0 = Dksvd'*pdf_reduced(:);
            f_dict = focuss_DSI_dictonary( f_dict0, q_space_reduced, DFT_cs, Dksvd, lambda, cg_iter, focuss_iter, p );

            pdf_dict = reshape(Dksvd * f_dict, size(pdf_true));
            % keep the dictionary estimate only where mask_cs is 0 and add the
            % measured samples (NOTE(review): assumes q_space_reduced is zero outside mask_cs)
            pdf_subs = DFT_all' * ((DFT_all * pdf_dict) .* (1-mask_cs) + q_space_reduced);          

            PDF_DICT(:,:,:,ay,cey) = pdf_subs;
            dict_rmse(ay,cey) = 100 * norm(pdf_subs(:)-pdf_true(:)) / norm(pdf_true(:));


            disp(['Voxels: ', num2str(count),' / ', num2str(total_voxels), '  Dictionary RMSE: ', num2str(dict_rmse(ay,cey)), '  L1 RMSE: ', num2str(L1_rmse(ay,cey))])

            count = count + 1;
        end
    end
end
toc

mean_dict = mean(dict_rmse(mask_test));
mean_L1 = mean(L1_rmse(mask_test));
figure(2), imagesc([imrotate(L1_rmse,-90), imrotate(dict_rmse,-90)], [0,20]), axis image,  ...
    title(['Average L1 RMSE: ', num2str(mean_L1), ' %     Average Dictionary RMSE: ', num2str(mean_dict), ' %'])




