function  [x,level] = ldwt(x, operation, wavelet, extension, minsize)
%LDWT 1-D and 2-D Discrete Wavelet Transform by Lifting Scheme
%
% Features:
%   If X is a (n x m) matrix
%   ldwt(X)           - Forward Wavelet Transform
%   ldwt(X,'forward') - the same as above
%   ldwt(X,'inverse') - Inverse Wavelet Transform
%   ldwt(X,'adjoint') - Adjoint of the Wavelet Transform
%   ldwt(X,'inverseadjoint') - Adjoint of the Inverse Wavelet Transform
%
% Full Argument List
%   ldwt(X, operation, wavelet, extension, minsize)
%     * operation may be 'forward' (default), 'inverse', 'adjoint',
%       'inverseadjoint' (or, equivalently, 'adjointinverse');
%       the comparison is case-insensitive
%     * wavelet may be 'haar', 'legall53' (default), 'cdf97'
%     * extension may be 'symmetric', 'zeropadded' (default)
%     * minsize is the size of a dimension for which no longer a transform
%       is performed. Sensible values are 1...size(x,1). Default is 8.
%       I.e. for vectors of even size N setting minsize to N/2 results
%       in a single wavelet decomposition level.
%
% Remarks:
%
%   [Wx, l] = ldwt(X) - l stores the sizes of the coarse coefficients
%   for example if l = [8 8; 16 16; 32 32] (coarsest level first, original
%   size last). Then the coarsest coefficients are stored in Wx(1:8,1:8),
%   the coefficients of the next finer level are in Wx(1:16,1:16). Finally
%   the finest level (i.e. original size) is Wx(1:32,1:32) = Wx;
%
%   The second output is available for every operation; it depends only
%   on size(X) and minsize.
%
%   The adjoint is such that with
%   Wx = ldwt(x); Wstary = ldwt(y,'adjoint');
%   we have
%   Wx(:)'*y(:) == x(:)'*Wstary(:)
%
%   If periodic extension is used then the results are different from the
%   results with MATLABs lwt function.
%
%   For sublevels of odd size the results are different than with
%   MATLABs lwt function.
%
%   The detail coefficients have opposite sign to the detail coefficients
%   of MATLABs dwt command.
%
% See also: LWT, LIFTWAVE, DWT, WAVEINFO

%   Copyright 2013 Kamil S. Kazimierski


% Fill in defaults for arguments the caller did not supply.
if nargin <= 1, operation = 'forward';   end
if nargin <= 2, wavelet   = 'legall53';  end
if nargin <= 3, extension = 'zeropadded'; end
if nargin <= 4, minsize   = 8;           end

% Encoded lifting scheme (steps, normalisation and boundary extension).
ls = lanalysis(wavelet,extension);

isforward = strcmpi(operation,'forward');
if isforward
    [x,level] = wavedecnd(x,ls,minsize);
elseif strcmpi(operation,'inverse')
    x = wavedecnd_inv(x,ls,minsize);
elseif strcmpi(operation,'adjoint')
    x = wavedecnd_star(x,ls,minsize);
elseif strcmpi(operation,'inverseadjoint') || strcmpi(operation,'adjointinverse')
    x = wavedecnd_star_inv(x,ls,minsize);
else
    error('Unknown operation');
end

% Only the forward branch returns the level table itself. For the other
% operations compute it on demand, so that requesting two outputs never
% fails with an unassigned output argument. The transforms work in place
% on sub-blocks of x, hence size(x) is the same before and after.
if ~isforward && nargout > 1
    level = sublevel_sizes_all_sublevels(size(x),minsize);
end
end

function [x, l] = wavedecnd(x,ls,m)
% Multi-level forward transform.
%   x  - data array, transformed in place on its leading sub-blocks
%   ls - lifting scheme matrix as produced by lanalysis
%   m  - minimal size; dimensions not larger than m are left untouched
%   l  - table of block sizes, coarsest level in the first row and the
%        size of x itself in the last row
l = sublevel_sizes_all_sublevels(size(x),m);
% Walk from the finest block (last row) down to the second row; the first
% row only describes the coarsest coefficients and is never transformed.
for row = size(l,1):-1:2
    block = selector(l(row,:));
    x(block{:}) = lwtnd(x(block{:}),ls,m);
end
end
function x = wavedecnd_inv(x,ls,m)
% Multi-level inverse transform: undoes wavedecnd level by level.
%   x  - coefficient array, processed in place on its leading sub-blocks
%   ls - lifting scheme matrix as produced by lanalysis
%   m  - minimal size used by the forward transform
sizes = sublevel_sizes_all_sublevels(size(x),m);
% Start with the second row (smallest block that was transformed) and
% finish with the last row, i.e. the full array.
for row = 2:size(sizes,1)
    block = selector(sizes(row,:));
    x(block{:}) = lwtnd_inv(x(block{:}),ls,m);
end
end
function x = wavedecnd_star(x,ls,m)
% Multi-level adjoint transform.
%   x  - data array, processed in place on its leading sub-blocks
%   ls - lifting scheme matrix as produced by lanalysis
%   m  - minimal size; dimensions not larger than m are left untouched
sizes = sublevel_sizes_all_sublevels(size(x),m);
% The forward transform visits the blocks from fine to coarse, so the
% adjoint visits them in the opposite order: second row first, full array
% last.
for row = 2:size(sizes,1)
    block = selector(sizes(row,:));
    x(block{:}) = lwtnd_star(x(block{:}),ls,m);
end
end
function x = wavedecnd_star_inv(x,ls,m)
% Multi-level inverse of the adjoint transform.
%   x  - data array, processed in place on its leading sub-blocks
%   ls - lifting scheme matrix as produced by lanalysis
%   m  - minimal size; dimensions not larger than m are left untouched
sizes = sublevel_sizes_all_sublevels(size(x),m);
% Blocks are visited from the full array (last row) down to the second
% row, i.e. in the reverse order of wavedecnd_star.
for row = size(sizes,1):-1:2
    block = selector(sizes(row,:));
    x(block{:}) = lwtnd_star_inv(x(block{:}),ls,m);
end
end

function [x,N1vec] = lwtnd(x,ls,m)
% One decomposition level of the tensor-product transform.
% Tensor-product real wavelets are used, hence
% W(X) = W1*X*W2^T = (W2 (W1(X))^T )^T
%   x     - block to transform
%   ls    - lifting scheme matrix as produced by lanalysis
%   m     - a dimension is transformed only if its length exceeds m
%   N1vec - size(x) on entry, with the entry of every transformed
%           dimension replaced by the coarse length reported by lwt1d
N1vec = size(x);
ndim  = ndims(x);   % fixed once, like the bounds of the original for-loop
dim   = 1;
while dim <= ndim
    sz   = size(x);
    % Flatten everything but the leading dimension so that lwt1d sees a
    % matrix whose columns are the 1-D signals.
    cols = reshape(x,sz(1),prod(sz(2:end)));
    if sz(1) > m
        [cols,N1vec(dim)] = lwt1d(cols,ls);
    end
    % Restore the shape and rotate the next dimension to the front.
    x   = shiftdim(reshape(cols,sz),1);
    dim = dim + 1;
end
end
function [x,N1vec] = lwtnd_inv(x,ls,m)
% Inverse of one tensor-product decomposition level.
%   x     - block of coefficients; the result has the same size
%   ls    - lifting scheme matrix as produced by lanalysis
%   m     - a dimension is processed only if its length exceeds m
%   N1vec - size(x), with the entry of every processed dimension replaced
%           by the coarse length reported by lwt1d_inv
%
% Dimensions are processed from the last to the first. Each dimension is
% brought to the front with an explicit permutation instead of repeated
% cyclic shifts: shiftdim drops trailing singleton dimensions, which made
% the shift count (ndims(x)-1) change mid-loop for arrays with an interior
% singleton dimension and caused one dimension to be inverted twice.
sz    = size(x);
nd    = numel(sz);
N1vec = sz;
for k = nd:-1:1
    if sz(k) > m
        order = [k, 1:k-1, k+1:nd];          % dimension k first
        y = permute(x,order);
        y = reshape(y,sz(k),prod(sz(order(2:end))));
        [y,N1vec(k)] = lwt1d_inv(y,ls);      % result stored at index k
        y = reshape(y,sz(order));
        x = ipermute(y,order);               % back to the original layout
    end
end
end
function [x,N1vec] = lwtnd_star(x,ls,m)
% Adjoint of one tensor-product decomposition level.
% This function is implemented such that with inner product
% inner = @(x,y) sum(sum(conj(x).*y,1),2);
% inner(lwtnd_star(x),y) == inner(x,lwtnd(y))
%   x     - block to transform; the result has the same size
%   ls    - lifting scheme matrix as produced by lanalysis
%   m     - a dimension is processed only if its length exceeds m
%   N1vec - size(x), with the entry of every processed dimension replaced
%           by the coarse length reported by lwt1d_star
%
% Dimensions are processed from the last to the first. Each dimension is
% brought to the front with an explicit permutation instead of repeated
% cyclic shifts: shiftdim drops trailing singleton dimensions, which made
% the shift count (ndims(x)-1) change mid-loop for arrays with an interior
% singleton dimension and caused one dimension to be processed twice.
sz    = size(x);
nd    = numel(sz);
N1vec = sz;
for k = nd:-1:1
    if sz(k) > m
        order = [k, 1:k-1, k+1:nd];          % dimension k first
        y = permute(x,order);
        y = reshape(y,sz(k),prod(sz(order(2:end))));
        [y,N1vec(k)] = lwt1d_star(y,ls);     % result stored at index k
        y = reshape(y,sz(order));
        x = ipermute(y,order);               % back to the original layout
    end
end
end
function [x,N1vec] = lwtnd_star_inv(x,ls,m)
% Inverse of the adjoint of one tensor-product decomposition level.
%   x     - block to transform; the result has the same size
%   ls    - lifting scheme matrix as produced by lanalysis
%   m     - a dimension is processed only if its length exceeds m
%   N1vec - size(x), with the entry of every processed dimension replaced
%           by the coarse length reported by lwt1d_star_inv
%
% Dimensions are processed in the order 2,3,...,nd,1 (the order the former
% shift-then-transform loop produced). Each dimension is brought to the
% front with an explicit permutation, so that
%   * N1vec(k) really belongs to dimension k (it used to be stored one
%     position off), and
%   * interior singleton dimensions are preserved in the result.
sz    = size(x);
nd    = numel(sz);
N1vec = sz;
for k = [2:nd, 1]
    if sz(k) > m
        order = [k, 1:k-1, k+1:nd];            % dimension k first
        y = permute(x,order);
        y = reshape(y,sz(k),prod(sz(order(2:end))));
        [y,N1vec(k)] = lwt1d_star_inv(y,ls);
        y = reshape(y,sz(order));
        x = ipermute(y,order);                 % back to the original layout
    end
end
end

%--------------------------------------------------------------------------
%--- Lifting Scheme -------------------------------------------------------
%--------------------------------------------------------------------------

function ls = lanalysis(scheme,extension)
% Build the matrix that encodes a lifting scheme.
%
% The scheme is encoded in an (N x 3) matrix.
% Rows 1..N-1 describe the lifting steps:
%   ls(k,3) is the type of the lifting step
%      * 0 == high/primal lifting step with filter c+d*z^(-1)
%      * 1 == low/dual lifting step with filter a*z + b
%     (i.e. it is the degree of the term of the first coefficient in the
%      z-domain sum c_n * z^(-n)   )
%   ls(k,1:2) are the lifting coefficients a,b resp. c,d
% Row N describes normalization and boundary handling:
%   ls(N,3) is the type of extension
%      * 0 == zero padding
%      * 1 == symmetric extension ( C B [A B C D] C B )
%   ls(N,1) normalization constant for x_0 (x(1), x(3) etc) / coarse
%      coefficients
%   1/ls(N,2) (mind the reciprocal!) normalization constant for x_1 /
%      detail coefficients
%
% scheme    - 'haar', 'legall53' or 'cdf97' (case-insensitive)
% extension - 'symmetric' or 'zeropadded' (case-insensitive)
if strcmpi(scheme,'haar')
    % BEWARE: This is the Haar-transform in the version of lwt, i.e.
    % [c,d] = lwt(x,'haar'); [c;d] == fwthaar_row(x),
    % (for signals of even length)
    %    1.) The output is different to [c,d]=dwt(x,'haar'); in particular
    %        in the dwt implementation the d-coefficients are the opposite
    %        sign to the d-coefficients of lwt.
    scaling = sqrt(2.0);
    steps   = [0.0 -1.0  1;...
               0.5  0.0  0];
elseif strcmpi(scheme,'legall53')
    % BEWARE: This is the LeGall 5/3-Transform (cdf2.2),
    % however MATLABs lwt uses zero-padding extension
    predict = -0.5;
    update  = 0.25;
    scaling = sqrt(2.0);
    steps   = [predict predict  1;...
               update  update   0];
elseif strcmpi(scheme,'cdf97')
    % BEWARE: This is the CDF 9/7-Transform (cdf4.4),
    % however MATLABs lwt uses other polyphase decomposition
    c1 = -1.586134342E+0;
    c2 = -5.298011854E-2;
    c3 =  0.8829110762E+0;
    c4 =  0.4435068522E+0;
    scaling = 1.149604398E0;
    steps   = [c1 c1  1;...
               c2 c2  0;...
               c3 c3  1;...
               c4 c4  0];
else
    error('Unknown lifting scheme.');
end

% Translate the extension name into the flag stored in ls(end,end).
if strcmpi(extension,'symmetric')
    extflag = 1;
elseif strcmpi(extension,'zeropadded')
    extflag = 0;
else
    error('Unknown extension');
end

% Last row: [coarse scaling, detail scaling, extension flag].
ls = [steps; scaling scaling extflag];
end
function [x,N1] = lwt1d(x,ls)
% Forward 1-D lifting transform along the first dimension of x.
%   x  - matrix whose columns are the signals; on return the coarse
%        coefficients occupy the first N1 rows, the details the rest
%   ls - lifting scheme matrix as produced by lanalysis
%   N1 - number of coarse coefficients
nrows   = size(x,1);
N1      = sublevel_sizes(nrows);
extflag = ls(end,end);
nsteps  = size(ls,1) - 1;
% --- Predict / Update ----------------------------------------------------
for k = 1:nsteps
    steptype = ls(k,3);
    c1 = ls(k,1);
    c2 = ls(k,2);

    % Boundary corrections for the chosen extension.
    switch extflag
        case 0
            eb = 0; ee = 0;
        case 1
            if steptype == 0 %high
                [eb,ee] = extension_high(c1,c2,nrows);
            else %low
                [eb,ee] = extension_low(c1,c2,nrows);
            end
        otherwise
            error('Unknown extension');
    end

    switch steptype
        case 0 %high
            x = lift_row_high(x, eb, c1, c2, ee);
        case 1 %low
            x = lift_row_low(x, eb, c1, c2, ee);
        otherwise
            error('Unknown step');
    end
end
% --- Scale ---------------------------------------------------------------
x = scale_row(x, ls(end,1), ls(end,2));
% --- Pack ----------------------------------------------------------------
x = pack_row(x);
end
function [x,N1] = lwt1d_star(x,ls)
% Adjoint of lwt1d: the adjoint stages are applied in reverse order
% (pack, scale, then the lifting steps from last to first).
%   x  - matrix whose columns are the signals
%   ls - lifting scheme matrix as produced by lanalysis
%   N1 - number of coarse coefficients for this signal length
nrows = size(x,1);
N1    = sublevel_sizes(nrows);
% --- Adjoint Pack --------------------------------------------------------
x = pack_row_star(x);
% --- Adjoint Scale -------------------------------------------------------
x = scale_row_star(x, ls(end,1), ls(end,2));
% --- Adjoint Predict / Update --------------------------------------------
extflag = ls(end,end);
for k = size(ls,1)-1:-1:1
    steptype = ls(k,3);
    c1 = ls(k,1);
    c2 = ls(k,2);

    % Boundary corrections for the chosen extension.
    switch extflag
        case 0
            eb = 0; ee = 0;
        case 1
            if steptype == 0 %high
                [eb,ee] = extension_high(c1,c2,nrows);
            else %low
                [eb,ee] = extension_low(c1,c2,nrows);
            end
        otherwise
            error('Unknown extension');
    end

    switch steptype
        case 0 %high
            x = lift_row_high_star(x, eb, c1, c2, ee);
        case 1 %low
            x = lift_row_low_star(x, eb, c1, c2, ee);
        otherwise
            error('Unknown step');
    end
end
end
function [x,N1] = lwt1d_inv(x,ls)
% Inverse of lwt1d: unpack, undo the scaling, then undo the lifting steps
% from last to first.
%   x  - matrix of packed coefficients (coarse rows first, then details)
%   ls - lifting scheme matrix as produced by lanalysis
%   N1 - number of coarse coefficients for this signal length
nrows = size(x,1);
N1    = sublevel_sizes(nrows);
% --- Undo Pack -----------------------------------------------------------
x = pack_row_inv(x);
% --- Undo Scale ----------------------------------------------------------
x = scale_row_inv(x, ls(end,1), ls(end,2));
% --- Undo Predict / Update -----------------------------------------------
extflag = ls(end,end);
for k = size(ls,1)-1:-1:1
    steptype = ls(k,3);
    c1 = ls(k,1);
    c2 = ls(k,2);

    % Boundary corrections for the chosen extension.
    switch extflag
        case 0
            eb = 0; ee = 0;
        case 1
            if steptype == 0 %high
                [eb,ee] = extension_high(c1,c2,nrows);
            else %low
                [eb,ee] = extension_low(c1,c2,nrows);
            end
        otherwise
            error('Unknown extension');
    end

    switch steptype
        case 0 %high
            x = lift_row_high_inv(x, eb, c1, c2, ee);
        case 1 %low
            x = lift_row_low_inv(x, eb, c1, c2, ee);
        otherwise
            error('Unknown step');
    end
end
end
function [x,N1] = lwt1d_star_inv(x,ls)
% Inverse of lwt1d_star: the lifting steps are undone from first to last,
% followed by the scaling and the packing stage.
%   x  - matrix whose columns are the signals
%   ls - lifting scheme matrix as produced by lanalysis
%   N1 - number of coarse coefficients for this signal length
nrows   = size(x,1);
N1      = sublevel_sizes(nrows);
extflag = ls(end,end);
nsteps  = size(ls,1) - 1;
% --- Adjoint Undo Predict / Update ---------------------------------------
for k = 1:nsteps
    steptype = ls(k,3);
    c1 = ls(k,1);
    c2 = ls(k,2);

    % Boundary corrections for the chosen extension.
    switch extflag
        case 0
            eb = 0; ee = 0;
        case 1
            if steptype == 0 %high
                [eb,ee] = extension_high(c1,c2,nrows);
            else %low
                [eb,ee] = extension_low(c1,c2,nrows);
            end
        otherwise
            error('Unknown extension');
    end

    switch steptype
        case 0 %high
            x = lift_row_high_star_inv(x, eb, c1, c2, ee);
        case 1 %low
            x = lift_row_low_star_inv(x, eb, c1, c2, ee);
        otherwise
            error('Unknown step');
    end
end
% --- Adjoint Undo Scale --------------------------------------------------
x = scale_row_star_inv(x, ls(end,1), ls(end,2));
% --- Adjoint Undo Pack ---------------------------------------------------
x = pack_row_star_inv(x);
end
function [eb,ee] = extension_low(a,b,N) %#ok<INUSL>
% Boundary corrections of a low (dual) lifting step under symmetric
% extension.
%   eb - correction at the beginning, always 0
%   ee - correction at the end: a for even N, 0 for odd N
% The coefficient b is accepted for a uniform signature but not used.
eb = 0;
ee = 0;
if mod(N,2) == 0
    ee = a;
end
end
function [eb,ee] = extension_high(c,d,N)
% Boundary corrections of a high (primal) lifting step under symmetric
% extension.
%   eb - correction at the beginning, always d
%   ee - correction at the end: 0 for even N, c for odd N
eb = d;
ee = c;
if mod(N,2) == 0
    ee = 0;
end
end

%--------------------------------------------------------------------------
%--- Lifting ingredients (Predict, Update, Scale, Pack) -------------------
%--------------------------------------------------------------------------

function x = pack_row_star_inv(x)
% Inverse of the adjoint of pack_row.
% Packing is a permutation of the rows, so its adjoint equals its
% inverse; inverting the adjoint therefore gives pack_row again.
x = pack_row(x);
end
function x = pack_row_star(x)
% Adjoint of pack_row.
% Packing is a permutation of the rows, so the adjoint is the inverse.
x = pack_row_inv(x);
end
function x = pack_row_inv(x)
% Inverse of pack_row: re-interleave the rows.
% The first N1 rows go back to the I1 positions (1,3,5,...) and the
% remaining rows go back to the I2 positions (2,4,6,...).
[iFirst, iSecond, nFirst] = split_index(x);
% Copy both halves before writing, the assignments below overwrite x.
firstHalf  = x(1:nFirst,:);
secondHalf = x(nFirst+1:end,:);
x(iFirst,:)  = firstHalf;
x(iSecond,:) = secondHalf;
end
function x = pack_row(x)
% De-interleave the rows of x: the rows at the I1 positions (1,3,5,...)
% come first, followed by the rows at the I2 positions (2,4,6,...).
[iFirst, iSecond] = split_index(x);
rowOrder = [iFirst, iSecond];
x = x(rowOrder,:);
end

function x = scale_row_star_inv(x, K,L)
% Inverse of the adjoint scaling step: scale_row with reciprocal factors.
recK = 1/K;
recL = 1/L;
x = scale_row(x, recK, recL);
end
function x = scale_row_star(x, K,L)
% Adjoint of the scaling step; delegates to scale_row with the same
% factors (the scaling acts row-wise with scalar factors).
x = scale_row(x, K, L);
end
function x = scale_row_inv(x, K,L)
% Inverse of the scaling step: scale_row with reciprocal factors.
recK = 1/K;
recL = 1/L;
x = scale_row(x, recK, recL);
end
function x = scale_row(x, K,L)
% Normalisation step on the interleaved signal.
%   rows at the I1 positions (1,3,5,...) are multiplied by K
%   rows at the I2 positions (2,4,6,...) are divided by L
[iFirst, iSecond] = split_index(x);

scaledFirst  = x(iFirst,:)*K;
scaledSecond = x(iSecond,:)/L;
x(iFirst,:)  = scaledFirst;
x(iSecond,:) = scaledSecond;
end

function x = lift_row_low_star_inv(x, eb,a,b,ee)
% Inverse of the adjoint of a low lifting step.
% The adjoint of a low step is a high step with the two coefficients
% swapped; negating all coefficients undoes it.
negEb = -eb;
negEe = -ee;
x = lift_row_high(x, negEb, -b, -a, negEe);
end
function x = lift_row_low_star(x, eb,a,b,ee)
% Adjoint of a low lifting step: a high lifting step with the two
% coefficients swapped (b takes the role of c, a the role of d) and the
% same boundary corrections.
c = b;
d = a;
x = lift_row_high(x, eb, c, d, ee);
end
function x = lift_row_low_inv(x, eb,a,b,ee)
% Inverse of a low lifting step: the same step with every coefficient
% (including the boundary corrections) negated.
negEb = -eb;
negEe = -ee;
x = lift_row_low(x, negEb, -a, -b, negEe);
end
function x = lift_row_low(x, eb,a,b,ee)
% Low (dual) lifting step on the interleaved signal.
% Every row at an I2 position (2,4,6,...) is updated from its two
% neighbours at I1 positions (1,3,5,...):
%   x(I2(j)) = x(I2(j)) + b*x(I1(j)) + a*x(I1(j+1))
%
%   x      - matrix whose columns are the signals
%   a, b   - lifting coefficients (a weights the following, b the
%            preceding I1 row)
%   eb, ee - boundary corrections added to the coefficient of the first
%            resp. last I1 row (0 for zero padding)
%
% Signals with a single I2 row (N == 2 or N == 3) are handled separately:
% there the first and the last I2 row coincide, and running the generic
% "begin" and "end" updates would lift that row twice.
[I1, I2, N1, N2, N] = split_index(x); %#ok<ASGLU>

if N2 == 0
    % No row at an I2 position, nothing to lift.
    return;
end

if N2 == 1
    if N1 == 1
        % N == 2: the only I1 row is both the first and the last one.
        x(I2(1),:) = x(I2(1),:) + (b+eb+ee)*x(I1(1),:);
    else
        % N == 3: one I2 row between the first and the last I1 row.
        x(I2(1),:) = x(I2(1),:) + (b+eb)*x(I1(1),:) + (a+ee)*x(I1(2),:);
    end
    return;
end

% --- N >= 4 ---------------------------------------------------------------
% at the begin
x(I2(1),:) = x(I2(1),:) + (b+eb)*x(I1(1),:) + a*x(I1(2),:);

% interior
x(I2(2:N2-1),:) = x(I2(2:N2-1),:) +  b*x(I1(2:N2-1),:) + a*x(I1(3:N2),:);
% N even -> N2-1 = N1-1
% N odd  -> N2-1 = N1-2

% extension at the end
if (mod(N,2) == 0)
    % the last I2 row has no following I1 row
    x(I2(N2),:) = x(I2(N2),:) + (b+ee)*x(I1(N1),:);
else
    x(I2(N2),:) = x(I2(N2),:) + b*x(I1(N1-1),:) + (a+ee)*x(I1(N1),:);
end
end

function x = lift_row_high_star_inv(x, eb,c,d,ee)
% Inverse of the adjoint of a high lifting step.
% The adjoint of a high step is a low step with the two coefficients
% swapped; negating all coefficients undoes it.
negEb = -eb;
negEe = -ee;
x = lift_row_low(x, negEb, -d, -c, negEe);
end
function x = lift_row_high_star(x, eb,c,d,ee)
% Adjoint of a high lifting step: a low lifting step with the two
% coefficients swapped (d takes the role of a, c the role of b) and the
% same boundary corrections.
a = d;
b = c;
x = lift_row_low(x, eb, a, b, ee);
end
function x = lift_row_high_inv(x, eb,c,d,ee)
% Inverse of a high lifting step: the same step with every coefficient
% (including the boundary corrections) negated.
negEb = -eb;
negEe = -ee;
x = lift_row_high(x, negEb, -c, -d, negEe);
end
function x = lift_row_high(x, eb,c,d,ee)
% High (primal) lifting step on the interleaved signal.
% Every row at an I1 position (1,3,5,...) is updated from its two
% neighbours at I2 positions (2,4,6,...):
%   x(I1(j)) = x(I1(j)) + d*x(I2(j-1)) + c*x(I2(j))
%
%   x      - matrix whose columns are the signals
%   c, d   - lifting coefficients (c weights the following, d the
%            preceding I2 row)
%   eb, ee - boundary corrections applied at the first resp. last I1 row
[I1, I2, N1, N2, N] = split_index(x); %#ok<ASGLU>

if N <= 2
    % Short signal: the single I1 row collects both boundary corrections.
    x(I1(1),:) = x(I1(1),:) + (c+eb+ee)*x(I2(1),:);
    return;
end

% extension at the begin
head = I1(1);
x(head,:) = x(head,:) + (c+eb)*x(I2(1),:);

% interior rows
mid = I1(2:N1-1);
x(mid,:) = x(mid,:) + d*x(I2(1:N1-2),:) + c*x(I2(2:N1-1),:);
% N even -> N1-2 = N2-2
% N odd  -> N1-2 = N2-1

% extension at the end
tail = I1(N1);
if mod(N,2) == 0
    x(tail,:) = x(tail,:) + d*x(I2(N1-1),:) +(c+ee)*x(I2(N1),:);
else
    % odd length: the last I1 row has no following I2 row
    x(tail,:) = x(tail,:) + (d+ee)*x(I2(N1-1),:);
end
end

%--------------------------------------------------------------------------
%--- Functions for original size, coarse and detail level sizes -----------
%--------------------------------------------------------------------------

function [I1, I2, N1, N2, N] = split_index(x)
% Row indices of the two interleaved halves of x.
%   N  - number of rows of x
%   N1 - number of odd-numbered rows, N2 - number of even-numbered rows
%        (as returned by sublevel_sizes)
%   I1 - 1,3,5,...  (N1 entries)
%   I2 - 2,4,6,...  (N2 entries)
N = size(x,1);
[N1,N2] = sublevel_sizes(N);
lastOdd  = 2*N1-1;
lastEven = 2*N2;
I1 = 1:2:lastOdd;
I2 = 2:2:lastEven;
end
function [N1,N2] = sublevel_sizes(N)
% Split a length N into the sizes of its two sub-levels.
%   N1 = ceil(N/2)   (coarse part)
%   N2 = N - N1      (detail part; equals floor(N/2) for integer N,
%                     hence N2 <= N1)
half = N/2;
N1 = ceil(half);
N2 = N - N1;
end
function level = sublevel_sizes_all_sublevels(s, m)
% Table of block sizes for all decomposition levels.
%   s     - size vector of the full array
%   m     - minimal sensible length for a wavelet decomposition, below
%           that length only extension (artifacts) are transformed.
%           A dimension is halved only while its length exceeds m.
%   level - one row per level, coarsest level first and s itself in the
%           last row
%
% A scalar m below 1 is rejected: a dimension of length 1 stays at
% ceil(1/2) == 1 and can never drop to m, so the loop below would not
% terminate.
if isscalar(m) && m < 1
    error('minsize must be at least 1.');
end

level = s;
while max(s) > m
    shrink = s > m;                          % dimensions still too long
    s(shrink) = sublevel_sizes(s(shrink));   % length of coarse level
    level = [s;level]; %#ok<AGROW>
end
end
function I = selector(N1vec)
% Build a cell array of index ranges that selects the leading block of
% size N1vec, i.e. x(I{:}) is x(1:N1vec(1), 1:N1vec(2), ...).
%   N1vec - vector of block sizes, one entry per dimension
%   I     - 1-by-length(N1vec) cell array with I{k} = 1:N1vec(k)
% The output is preallocated, so it is also defined (an empty cell array)
% when N1vec is empty.
n = length(N1vec);
I = cell(1,n);
for k = 1:n
    I{k} = 1:N1vec(k);
end
end