clear all ;
close all ;

% lecture de l'image
im = imread('serre.png') ;
figure (1) ;
imagesc(im) ; axis equal ;
title('image de depart') ;

N = size(im,1) ;

% determination des pixels rouges
inds = find((im(:,:,1)==255).*(im(:,:,2)==0)) ;
inds_c = find((im(:,:,1)<255)+(im(:,:,2)>0)) ;

im = double(mean(im,3)) ;

lambda = 6 ;
[x,aphi] = wavelet_transform(im) ; % initialisation de x
J = size(x,2) ;

% iterations
for kit=1:500
    
    inv_x = inverse_wavelet_transform(x,aphi) ;
    inv_x(inds_c) = im(inds_c) ;
    [x,aphi] = wavelet_transform(inv_x) ;
    for j=1:J
        x{j} = sign(x{j}).*max(0,abs(x{j})-lambda) ;        
    end
    aphi = sign(aphi)*max(0,abs(aphi)-lambda) ;    
    
end

% image reconstruite et affichage
im_rec = inverse_wavelet_transform(x,aphi) ;
figure(2) ;
imagesc(im_rec,[0,255]) ;
title('image reconstruite') ;
colormap(gray(256)) ;
axis equal ;


% amelioration par petites translations et moyennage
bool = zeros(N) ;
bool(inds) = 1 ;

lambda = 3 ;
K = 0 ;
R = 1 ;
for r1=-R:R
    for r2 = -R:R

        im2 = circshift(im,[r1,r2]) ;
        inds_c = find(1-circshift(bool,[r1,r2])) ;
        
        K = K+1 ;
        for k=1:500
            
            inv_x = inverse_wavelet_transform(x,aphi) ;
            inv_x(inds_c) = im2(inds_c) ;
            [x,aphi] = wavelet_transform(inv_x) ;
            for j=1:J
                x{j} = sign(x{j}).*max(0,abs(x{j})-lambda) ;        
            end
            aphi = sign(aphi)*max(0,abs(aphi)-lambda) ;    
            
        end

        sauv(:,:,K) = circshift(inverse_wavelet_transform(x,aphi), ...
                                [-r1,-r2]) ;
    end
end

im_rec = mean(sauv,3) ;
    
figure(3) ;
imagesc(im_rec,[0,255]) ;
title('image reconstruite avec moyennage') ;
axis equal ;
colormap(gray(256)) ;

im = imread('serre_original.png') ;
im = double(im) ;
if (size(im,3)>1)
    im = mean(im,3) ;
end
figure(4) ;
imagesc(im,[0,255]) ;
title('image originale') ;
axis equal ;
colormap(gray(256)) ;
