classdef (Abstract) DLModel < Model
    % DLMODEL  A deep learning model.
    %
    % Training loop based on
    % https://mathworks.com/help/deeplearning/ug/train-network-using-custom-training-loop.html.

    properties (Access = private)
        % NETWORK  The trained network.
        network (1, 1) dlnetwork = dlnetwork(layerGraph([sequenceInputLayer(1)]));
        % VELOCITY  The velocities obtained for each learnable parameter of [network]
        % during gradient descent as a result of momentum.
        %
        % If the momentum is 0, this property is not relevant.
        velocity (:, 3) table;
        % ITERATION  The total number of iterations this model has been trained for,
        % accumulated across all calls to [train]. Used for learning-rate decay.
        iteration (1, 1) {mustBeInteger, mustBeNonnegative} = 0;
    end

    properties (Abstract, Access = protected)
        % CLASSES  The classes that this model distinguishes.
        classes (:, 1);
    end


    methods
        function obj = DLModel(layers)
            % DLMODEL  Constructs a new model with an initialized but untrained network of
            % the given [layers].

            obj.network = dlnetwork(layerGraph(layers));
        end


        function loss = train(obj, dataset_train, epochs, mini_batch_size, initial_learn_rate, decay, momentum)
            % TRAIN  Trains [network] on [dataset_train] for [epochs] epochs using SGD with
            % momentum ([momentum]) and time-based learning-rate decay.
            %
            % The learning rate at (global) iteration k is
            % [initial_learn_rate] / (1 + [decay] * k), where k counts all iterations
            % over all calls to [train]. Mini-batches are reshuffled every epoch, and a
            % trailing partial mini-batch is discarded.
            %
            % Returns the [loss] of the last processed mini-batch, or dlarray(NaN) if no
            % mini-batch was processed (e.g. when [dataset_train] has fewer rows than
            % [mini_batch_size]).

            arguments% (Input)
                obj (1, 1) DLModel;
                dataset_train (:, 2) table {mustHaveColumns(dataset_train, ["inputs", "labels"])};
                epochs (1, 1) {mustBeInteger, mustBePositive};
                mini_batch_size (1, 1) {mustBeInteger, mustBePositive};
                initial_learn_rate (1, 1) double;
                decay (1, 1) double;
                momentum (1, 1) double;
            end
            % arguments (Output)
            %     loss (1, 1) dlarray;
            % end

            mbq = minibatchqueue(arrayDatastore(dataset_train, OutputType = "same"), ...
                                 MiniBatchSize = mini_batch_size, ...
                                 MiniBatchFcn = @DLModel.preprocess_mini_batch, ...
                                 MiniBatchFormat = ["SSCB", ""], ...
                                 PartialMiniBatch = "discard");

            % Ensure [loss] is always assigned: with PartialMiniBatch = "discard", a
            % dataset smaller than one mini-batch yields no iterations at all, which
            % would otherwise raise "Output argument 'loss' not assigned".
            loss = dlarray(NaN);

            for epoch = 1:epochs
                shuffle(mbq);
                while hasdata(mbq)
                    obj.iteration = obj.iteration + 1;

                    [data, labels] = next(mbq);

                    [loss, gradients, state] = dlfeval(@DLModel.loss_function, obj.network, data, labels);
                    % Keep non-learnable state (e.g. normalization statistics) in sync.
                    obj.network.State = state;

                    learn_rate = initial_learn_rate / (1 + decay * obj.iteration);
                    [obj.network, obj.velocity] = ...
                        sgdmupdate(obj.network, gradients, obj.velocity, learn_rate, momentum);
                end
            end
        end

        function combine_with_neighbors(obj, self_weight, neighbor_models)
            % COMBINE_WITH_NEIGHBORS  Replaces the learnable parameters and momentum
            % velocities of this model by a weighted average of its own values (weight
            % [self_weight]) and those of each model in [neighbor_models] (weight 1 each).
            %
            % Neighbors whose velocity table is empty are ignored when averaging
            % velocities. Does nothing if [neighbor_models] is empty.

            arguments% (Input)
                obj (1, 1) DLModel;
                self_weight (1, 1) {mustBeInteger, mustBeNonnegative};
                neighbor_models (:, 1) cell;  % cell<DLModel>
            end

            if height(neighbor_models) == 0
                return
            end

            obj.network.Learnables = DLModel.calc_value_step(self_weight, ...
                                                             obj.network.Learnables, ...
                                                             cellfun(@(it) it.network.Learnables, ...
                                                                     neighbor_models(:), ...
                                                                     UniformOutput = false));
            obj.velocity = DLModel.calc_value_step(self_weight, ...
                                                   obj.velocity, ...
                                                   cellfun(@(it) it.velocity, ...
                                                           neighbor_models(:), ...
                                                           UniformOutput = false));
        end


        function accuracy = evaluate_accuracy(obj, dataset_test, mini_batch_size)
            % EVALUATE_ACCURACY  Returns the fraction of rows in [dataset_test] whose
            % predicted class equals the row's label.
            %
            % Every row of [dataset_test] is predicted, including those in a trailing
            % partial mini-batch, so the accuracy is computed over the whole dataset.

            arguments% (Input)
                obj (1, 1) DLModel;
                dataset_test (:, 2) table {mustHaveColumns(dataset_test, ["inputs", "labels"])};
                mini_batch_size (1, 1) {mustBeInteger, mustBePositive};
            end
            % arguments (Output)
            %     accuracy (1, 1) double;
            % end

            % Get predictions. The trailing partial mini-batch must be returned (not
            % discarded): the accuracy below is averaged over all rows of
            % [dataset_test], so any unpredicted row would keep its placeholder value and
            % distort the result.
            mbq = minibatchqueue(arrayDatastore(dataset_test, OutputType = "same"), ...
                                 1, ...
                                 MiniBatchSize = mini_batch_size, ...
                                 MiniBatchFcn = @DLModel.preprocess_mini_batch_predictors, ...
                                 MiniBatchFormat = "SSCB", ...
                                 PartialMiniBatch = "return");

            % Placeholder values; every element is overwritten by the loop below.
            predictions = categorical(zeros([height(dataset_test), 1]), obj.classes);
            i = 1;
            while hasdata(mbq)
                samples = next(mbq);

                scores = predict(obj.network, samples);
                labels = onehotdecode(scores, obj.classes, 1)';

                predictions(i:(i + height(labels) - 1)) = labels;
                i = i + height(labels);
            end

            % Calculate performance
            accuracy = mean(predictions == dataset_test.labels);
        end

        function flat_learnables = get_flat_learnables(obj)
            % GET_FLAT_LEARNABLES  Returns all learnable parameter values of [network]
            % concatenated into one numeric column vector: parameters in the row order of
            % the Learnables table, each flattened in column-major order.

            arguments% (Input)
                obj (1, 1) DLModel;
            end
            % arguments (Output)
            %     flat_learnables (:, 1) double;
            % end

            params = cellfun(@(it) reshape(it, [], 1), obj.network.Learnables.Value, UniformOutput = false);
            flat_learnables = extractdata(vertcat(params{:}));
        end
    end

    methods (Static, Access = private)
        function [loss, gradients, state] = loss_function(network, data, labels)
            % LOSS_FUNCTION  Calculates the [loss], [gradients], and [state] of this model
            % over [data] and [labels].
            %
            % Must be evaluated through dlfeval so that dlgradient can trace the forward
            % pass.

            arguments% (Input)
                network (1, 1) dlnetwork;
                data dlarray;
                labels dlarray;
            end
            % arguments (Output)
            %     loss (1, 1) dlarray;
            %     gradients (:, 3) table;
            %     state (:, 3) table;
            % end

            [predictions, state] = forward(network, data);
            loss = crossentropy(predictions, labels);
            gradients = dlgradient(loss, network.Learnables);
        end

        function X = preprocess_mini_batch_predictors(dataX)
            % PREPROCESS_MINI_BATCH_PREDICTORS  Concatenates the predictors in [dataX]
            % (one cell per observation) along dimension 4, so that observations are
            % stacked along the batch dimension of the "SSCB" format.

            arguments% (Input)
                dataX (:, 1) cell;
            end
            % arguments (Output)
            %     X (28, 28, 1, :) double;
            % end

            X = cat(4, dataX{:});
        end

        function [X, T] = preprocess_mini_batch(dataX, dataT)
            % PREPROCESS_MINI_BATCH  Preprocesses [dataX] into [X] and [dataT] into [T] for
            % [train].
            %
            % [X] is built as in [preprocess_mini_batch_predictors]. [T] is the labels
            % concatenated into a row (one column per observation) and one-hot encoded
            % along dimension 1. Presumably [dataT] holds categorical scalars, as
            % required by onehotencode.

            arguments% (Input)
                dataX (:, 1) cell;
                dataT (:, 1) cell;
            end
            % arguments (Output)
            %     X (28, 28, 1, :) double;
            %     T (:, :) double;
            % end

            % Preprocess predictors
            X = DLModel.preprocess_mini_batch_predictors(dataX);

            % Extract label data from cell and concatenate
            T = cat(2, dataT{:});

            % One-hot encode labels
            T = onehotencode(T, 1);
        end

        function avg_table = calc_value_step(self_weight, self_table, other_tables)
            % CALC_VALUE_STEP  Calculates the weighted average along the Value column of
            % [self_table] (with weight [self_weight]) and each [other_tables] (excluding
            % empty tables) (each with weight 1).
            %
            % A table counts as empty if it has no Value column or its Value column has
            % no rows. If all [other_tables] are empty, [self_table] is returned
            % unchanged. Non-Value columns are copied from [self_table].

            arguments% (Input)
                self_weight (1, 1) {mustBeInteger, mustBeNonnegative};
                self_table (:, 3) table;
                other_tables (:, 1) cell;
            end
            % arguments (Output)
            %     avg_table (:, 3) table;
            % end

            non_empty_mask = cellfun(@(it) any(it.Properties.VariableNames == "Value") && height(it.Value) > 0, ...
                                     other_tables);
            if sum(non_empty_mask) == 0
                avg_table = self_table;
                return;
            end

            total_weight = self_weight + sum(non_empty_mask);

            avg_table = self_table;
            for row_idx = 1:height(self_table)
                avg_row = avg_table.Value{row_idx} * self_weight;
                row_ndims = ndims(avg_table.Value{row_idx});

                % Stack the corresponding parameter of every non-empty neighbor along a
                % new trailing dimension, then sum over that dimension.
                other_rows = cellfun(@(it) it.Value{row_idx}, other_tables(non_empty_mask), UniformOutput = false);
                other_rows_sum = sum(cat(row_ndims + 1, other_rows{:}), row_ndims + 1);

                avg_table.Value{row_idx} = (avg_row + other_rows_sum) / total_weight;
            end
        end
    end
end
