%-------------------------------------------------------------------------%
% Function to compute the total transmission tensor A_T (a 2x2 matrix)
% of a multi-layered system
% Created by X.Kuci, TU/e, 2024
%-------------------------------------------------------------------------%
function A_T = get_transm_tensor(layers, f, k0, theta0, e_x, dx)
%--------------------------------------------------------------------------
% Computes the 2x2 total transmission tensor (A_T) for an N-layer system.
% Handles any mix of homogeneous ('hom') and enriched layers; any .type
% other than 'hom' (case-insensitive) is treated as an enriched layer.
%
% INPUTS:
%   layers: struct array, one element per layer, with fields:
%       .type : 'hom' or 'enriched'
%       .HM   : material struct (precomputed)
%       .W    : unit cell width
%       .N    : number of unit cells
%     Only layers(2:end) are read here; layers(1) is presumably the input
%     medium, described through k0/theta0 (NOTE(review): confirm with callers).
%     layers(end) must be 'hom' so that its projection P can be obtained
%     from get_B_hom, or it must already carry a non-empty .P field.
%   f     : frequency [Hz]
%   k0    : incident wavenumber (input medium)
%   theta0: incident angle [rad]
%   e_x   : direction vector (usually [1;0])
%   dx    : forwarded unchanged to get_B_hom for homogeneous layers
%           (not passed to get_B for enriched layers)
%
% OUTPUT:
%   A_T   : 2x2 transmission tensor mapping input to output at the
%           interface, computed as layers(end).P * A_TT, where A_TT is the
%           Schur complement of the A(3:4,3:4) block of the chained
%           layer matrix A.
%
% ERRORS:
%   get_transm_tensor:tooFewLayers - fewer than two layers were given.
%   get_transm_tensor:noProjection - no projection P is available for the
%                                    last layer.
%--------------------------------------------------------------------------

nLayers = numel(layers);
if nLayers < 2
    error('get_transm_tensor:tooFewLayers', ...
        'At least two layers are required (got %d).', nLayers);
end

% Precompute B0 and Bell for layers 2..nLayers, plus the projection P of
% the last (output) layer when it is homogeneous.
for r = 2:nLayers
    ell = layers(r).N * layers(r).W;   % total width: N unit cells of width W
    if strcmpi(layers(r).type, 'hom')
        % A single call yields B0, Bell and P; P is kept only for the last
        % layer (previously get_B_hom was called a second time for it).
        [B0, Bell, P] = get_B_hom(layers(r).HM, layers(r).W, f, k0, theta0, ell, e_x, dx);
        if r == nLayers
            layers(r).P = P;
        end
    else
        [B0, Bell] = get_B(layers(r).HM, layers(r).W, f, k0, theta0, ell, e_x);
    end
    layers(r).B0   = B0;
    layers(r).Bell = Bell;
end

% The final projection needs P. Only get_B_hom supplies it, unless the
% caller provided it already.
if ~isfield(layers, 'P') || isempty(layers(nLayers).P)
    error('get_transm_tensor:noProjection', ...
        ['No projection P is available for the last layer (type ''%s''); ' ...
         'the output layer must be homogeneous (''hom'').'], ...
        layers(nLayers).type);
end

% Recursive chain:
%   A = B0_N \ (Bell_{N-1} * (B0_{N-1} \ ( ... Bell_2 * (B0_2 \ I))))
A = eye(size(layers(2).B0));
for r = 2:nLayers-1
    A = layers(r).Bell * (layers(r).B0 \ A);
end
A = layers(nLayers).B0 \ A;

% Schur complement: condense out the 3:4 unknowns (presumably tractions)
% to keep the 2x2 displacement submatrix.
A_uu = A(1:2,1:2);
A_ut = A(1:2,3:4);
A_tu = A(3:4,1:2);
A_tt = A(3:4,3:4);
A_TT = A_uu - A_ut * (A_tt \ A_tu);

% Projection (output layer, homogeneous and semi-infinite)
A_T = layers(nLayers).P * A_TT;

end
