%% load data
% Load the numerical phantom, its mask and the voxel size.
% NOTE(review): chi, mask_use and spatial_res are never assigned in this
% script, so they are assumed to come from these MAT-files — confirm.
load('chi_phantom')
load('mask_phantom')
load('spatial_res')

% Volume dimensions, reused by every later section.
N = size(chi);

% Ground truth display: 1 is subtracted wherever mask_use==0 so those voxels
% fall toward the bottom of the [-0.12, 0.12] window.
% NOTE(review): imagesc3d2 is external; its third argument appears to be the
% figure number (1..6 across this script) — confirm.
imagesc3d2(chi - (mask_use==0), N/2, 1, [90,90,90], [-0.12,0.12], 0, 'True Susceptibility')
 
  
%% create dipole kernel and noisy phase
% Builds the unit dipole kernel D(k) in unshifted DFT layout and simulates a
% noisy phase measurement  phase = F^-1[ D F(chi) ] + Gaussian noise.

% Centered integer frequency grids. meshgrid's first output varies along
% dimension 2 and its second along dimension 1, hence the [ky,kx,kz] order.
[ky,kx,kz] = meshgrid(-N(2)/2:N(2)/2-1, -N(1)/2:N(1)/2-1, -N(3)/2:N(3)/2-1);

% Scale each axis by its largest magnitude (N(i)/2), then divide by the voxel
% size along that axis so anisotropic voxels are handled.
kx = (kx / max(abs(kx(:)))) / spatial_res(1);
ky = (ky / max(abs(ky(:)))) / spatial_res(2);
kz = (kz / max(abs(kz(:)))) / spatial_res(3);

k2 = kx.^2 + ky.^2 + kz.^2;

% Rotation of the main field; identity => field direction is R_tot(3,:) = z.
R_tot = eye(3);

% D(k) = 1/3 - (k . b0)^2 / |k|^2. The +eps keeps DC finite (D = 1/3 there),
% and fftshift moves DC to index (1,1,1) to match the fftn layout.
kernel = fftshift( 1/3 - (kx * R_tot(3,1) + ky * R_tot(3,2) + kz * R_tot(3,3)).^2 ./ (k2 + eps) );    

noise_std = 2.9e-3;     % std of the additive Gaussian noise

phase_true = ifftn(kernel .* fftn(chi));
phase_use = phase_true + randn(N) * noise_std;

% Noise level relative to the noiseless phase inside the mask, in percent.
rmse_noise = 100 * norm(mask_use(:) .* (phase_use(:) - phase_true(:))) / norm(mask_use(:).*phase_true(:));

imagesc3d2(phase_use, N/2, 2, [90,90,90], [-0.04,0.04], 0, ['Noise RMSE: ', num2str(rmse_noise)]) 


%% TKD recon
% Truncated k-space division: the dipole kernel is inverted only where
% |D| > kthre; every other frequency of the inverse is left at zero.

kthre = 0.08;       % truncation threshold

kmask = abs(kernel) > kthre;          % frequencies that are safe to invert
kernel_inv = zeros(N);
kernel_inv(kmask) = 1 ./ kernel(kmask);

tic
    chi_tkd = real( ifftn( kernel_inv.* fftn(phase_use) ) ) .* mask_use; 
toc

% Relative error (percent) of the masked estimate against the full ground truth.
rmse_tkd = 100 * norm(real(chi_tkd(:)).*mask_use(:) - chi(:)) / norm(chi(:));
tkd_label = ['TKD RMSE: ', num2str(rmse_tkd)];
imagesc3d2(chi_tkd .* mask_use - (mask_use==0), N/2, 3, [90,90,90], [-0.12,0.12], 0, tkd_label)
 

%% Closed-form L2 recon
% Gradient-regularized least squares solved directly in k-space:
%
%   chi = F^-1[ conj(D) F(phase) / (|D|^2 + beta * sum_i |E_i|^2) ]
%
% E_i = 1 - exp(2*pi*1i*k_i/N_i) is a circular first-order finite difference
% in the unshifted DFT domain. Note: kx/ky/kz from the kernel section are
% overwritten here with 0-based DFT indices.

[kx, ky, kz] = ndgrid(0:N(1)-1, 0:N(2)-1, 0:N(3)-1);

% Difference operators and their adjoints, axis by axis.
Ex = 1 - exp(2i .* pi .* kx / N(1));
Ext = conj(Ex);

Ey = 1 - exp(2i .* pi .* ky / N(2));
Eyt = conj(Ey);

Ez = 1 - exp(2i .* pi .* kz / N(3));
Ezt = conj(Ez);

% Squared magnitudes of the gradient and dipole operators.
E2 = Ext .* Ex + Eyt .* Ey + Ezt .* Ez;
K2 = abs(kernel).^2;

beta = 3e-3;    % regularization parameter

tic
    rhs_k = conj(kernel) .* fftn(phase_use);
    chi_L2 = real( ifftn(rhs_k ./ (K2 + beta * E2)) ) .* mask_use;
toc

rmse_L2 = 100 * norm(real(chi_L2(:)).*mask_use(:) - chi(:)) / norm(chi(:));
l2_label = ['L2 RMSE: ', num2str(rmse_L2)];
imagesc3d2(chi_L2 .* mask_use - (mask_use==0), N/2, 4, [90,90,90], [-0.12,0.12], 0, l2_label)


%% TV ADMM recon
% Total-variation regularized dipole inversion with scaled ADMM:
%
%   min_x  1/2 || D x - phase ||_2^2 + lambda || G x ||_1 ,   G = (Ex, Ey, Ez)
%
% The split z = G x is enforced with penalty mu and scaled multiplier s; the
% x-subproblem is solved in closed form in k-space. Uses kernel, K2, E2,
% Ex/Ey/Ez and Ext/Eyt/Ezt built in the kernel and Closed-form L2 sections.

mu = 1e-2;              % gradient consistency
lambda = 2e-4;          % gradient L1 penalty
 
num_iter = 50;
tol_update = 1;         % stop when the relative change of x (percent) is below this

% ADMM state, allocated in single precision
z_dx = zeros(N, 'single');
z_dy = zeros(N, 'single');
z_dz = zeros(N, 'single');

s_dx = zeros(N, 'single');
s_dy = zeros(N, 'single');
s_dz = zeros(N, 'single');

x = zeros(N, 'single');

kspace = fftn(phase_use);
Dt_kspace = conj(kernel) .* kspace;     % D^H applied to the measured phase

% Loop-invariant quantities: x-update denominator and soft-threshold level.
denom = eps + K2 + mu * E2;
thresh = lambda / mu;

tic
for t = 1:num_iter
    % update x : susceptibility estimate (closed-form k-space solve)
    tx = Ext .* fftn(z_dx - s_dx);
    ty = Eyt .* fftn(z_dy - s_dy);
    tz = Ezt .* fftn(z_dz - s_dz);
    
    x_prev = x;
    x = ifftn( (mu * (tx + ty + tz) + Dt_kspace) ./ denom );

    % Relative change in percent. The max(...,eps) guard avoids 0/0 = NaN when
    % x is identically zero; NaN would never satisfy the tolerance test.
    x_update = 100 * norm(x(:)-x_prev(:)) / max(norm(x(:)), eps);
    disp(['Iter: ', num2str(t), '   Update: ', num2str(x_update)])
    
    if x_update < tol_update
        break
    end
    
    if t < num_iter     % z and s are not needed after the last x update
        % update z : gradient variable (soft-thresholding; sign() is z./|z|
        % for complex inputs)
        Fx = fftn(x);
        x_dx = ifftn(Ex .* Fx);
        x_dy = ifftn(Ey .* Fx);
        x_dz = ifftn(Ez .* Fx);

        z_dx = max(abs(x_dx + s_dx) - thresh, 0) .* sign(x_dx + s_dx);
        z_dy = max(abs(x_dy + s_dy) - thresh, 0) .* sign(x_dy + s_dy);
        z_dz = max(abs(x_dz + s_dz) - thresh, 0) .* sign(x_dz + s_dz);

        % update s : scaled Lagrange multiplier
        s_dx = s_dx + x_dx - z_dx;
        s_dy = s_dy + x_dy - z_dy;
        s_dz = s_dz + x_dz - z_dz;
    end
end
toc

rmse_tv = 100 * norm(real(x(:)).*mask_use(:) - chi(:)) / norm(chi(:));
imagesc3d2(real(x) .* mask_use - (mask_use==0), N/2, 5, [90,90,90], [-0.12,0.12], 0, ['TV RMSE: ', num2str(rmse_tv), '  iter : ', num2str(t)])


%% TGV ADMM recon
% Second-order total generalized variation reconstruction, delegated to the
% external function TGV_3D_CF.
% NOTE(review): TGV_3D_CF is not defined in this file; the field meanings
% below come from the inline comments — confirm against that function.

% Regularization weights and ADMM penalties
params.mu1 = 1e-2;                  % gradient consistency
params.mu0 = params.mu1;            % second order gradient consistency
params.alpha1 = 2e-4;               % gradient L1 penalty
params.alpha0 = 2 * params.alpha1;  % second order gradient L1 penalty

% Iteration control, reusing num_iter / tol_update from the TV section
params.maxOuterIter = num_iter;
params.tol_update = tol_update;
params.N = N;

% Measured phase in k-space and the dipole kernel
params.kspace = fftn(phase_use);
params.K = kernel;

out = TGV_3D_CF(params);

% The script reads out.x (estimate) and out.iter from the result.
rmse_tgv_cf = 100 * norm(real(out.x(:)).*mask_use(:) - chi(:)) / norm(chi(:));
tgv_label = ['TGV RMSE: ', num2str(rmse_tgv_cf), '  iter : ', num2str(out.iter)];
imagesc3d2(real(out.x) .* mask_use - (mask_use==0), N/2, 6, [90,90,90], [-0.12,0.12], 0, tgv_label)


%% Zoom into reconstructions
% Cropped three-panel views of the ground truth and three reconstructions,
% one figure each (figures 10-13), grayscale on a black background.
% Voxels outside mask_use are offset by -1.

ay = 116;    % fixed index along dim 1 (third panel)
cey = 120;   % fixed index along dim 2 (second panel)
zi = 52;     % fixed index along dim 3 (first panel)

zoom_vols   = {chi, chi_L2, x, out.x};
zoom_titles = {'True Susceptibility', 'Closed-form L2', 'TV ADMM', 'TGV ADMM'};

for iv = 1:numel(zoom_vols)
    chi_disp = real(zoom_vols{iv}) - (mask_use==0);
    fig_num = 9 + iv;            % figures 10, 11, 12, 13
    fig_title = zoom_titles{iv};

    figure(fig_num)

    % panel 1: plane zi, cropped along dims 1 and 2
    subplot(131)
    imagesc(imrotate(chi_disp(40:end-40,35:end-40,zi), 90), [-.2,.2])
    axis image
    colormap gray

    % panel 2: plane cey of dim 2, cropped along dim 1; carries the title
    subplot(132)
    imagesc(imrotate(squeeze(chi_disp(40:end-40,cey,:)), 90), [-.2,.2])
    axis square
    title(fig_title, 'color', 'w', 'fontsize', 32)

    % panel 3: plane ay of dim 1, cropped along dim 2, mirrored left-right
    subplot(133)
    imagesc(fliplr(imrotate(squeeze(chi_disp(ay,35:end-40,:)), 90)), [-.2,.2])
    set(gcf,'color','k')
    axis square
end


