classdef ExperimentConfig < Config
    % EXPERIMENTCONFIG  Configuration for running an [Experiment].
    %
    % All options are public properties with validators; instances are created and
    % modified through name-value arguments (see the constructor and [set]).

    properties
        % NODE_COUNT  Number of nodes in the network.
        node_count (1, 1) {mustBeInteger, mustBePositive} = 60;

        % LOG_SETTINGS  `true` if and only if settings should be displayed at the start
        % of the experiment.
        log_settings (1, 1) logical = false;
        % LOG_ROUNDS  `true` if and only if round-specific information should be logged.
        log_rounds (1, 1) logical = false;

        % DATASET_NAME  Name of the dataset to experiment on.
        %
        % If `"number"` is selected, a `NumberModel` is used, and all `train_*` options
        % are ignored.
        %
        % Must be one of: "mnist", "emnist", "number". Character vectors (e.g.
        % 'mnist') are accepted and converted to strings.
        dataset_name (1, 1) string {mustBeMember(dataset_name, ["mnist", "emnist", "number"])} = "mnist";
        % DATASET_SAMPLES  Number of samples in the dataset to actually use, or a
        % negative number to use all samples in the dataset.
        dataset_samples (1, 1) {mustBeInteger} = -1;
        % DATASET_SPLIT_FRACTION  Numerator and denominator of fraction of samples to
        % use for training.
        %
        % For example, a value of `[4, 5]` assigns 80% of samples for training, and 20%
        % for testing.
        dataset_split_fraction (1, 2) {mustBeInteger, mustBeNonnegative} = [4, 5];
        % DATASET_IID  `true` if and only if data should be split i.i.d. over clients.
        dataset_iid (1, 1) logical = false;
        % DATASET_IID_DIRICHLET_ALPHA  The alpha (concentration) parameter of the
        % Dirichlet distribution used to generate non-i.i.d. distributions.
        %
        % Must be strictly positive: a Dirichlet distribution is undefined for
        % alpha = 0.
        %
        % Only relevant for non-i.i.d. splits, i.e. ignored if [dataset_iid] is `true`.
        dataset_iid_dirichlet_alpha (1, 1) {mustBeFloat, mustBePositive} = 1;
        % DATASET_IID_EMNIST_BY_WRITER  `true` if and only if each node corresponds to
        % one writer.
        %
        % Ignored if [dataset_name] is not "emnist". A per-writer split is inherently
        % non-i.i.d., so this is presumably also ignored if [dataset_iid] is `true`
        % (the previous docs said `false`) -- TODO confirm against the data-splitting
        % code.
        dataset_iid_emnist_by_writer (1, 1) logical = false;

        % NETWORK_LAYOUT  Graph type to use for the network layout.
        %
        % Must be one of: "erdos-renyi", "watts-strogatz", "barabasi-albert",
        % "complete", "empty". Character vectors are accepted and converted to strings.
        network_layout (1, 1) string {mustBeMember(network_layout, ["erdos-renyi", ...
                                                                    "watts-strogatz", ...
                                                                    "barabasi-albert", ...
                                                                    "complete", ...
                                                                    "empty"])} = "erdos-renyi";
        % NETWORK_ERDOS_RENYI_P  The `p` parameter for Erdős–Rényi graphs.
        network_erdos_renyi_p (1, 1) {mustBeFloat, mustBeInRange(network_erdos_renyi_p, 0, 1)} = 0.1;
        % NETWORK_WATTS_STROGATZ_K  The `k` parameter for Watts–Strogatz graphs.
        network_watts_strogatz_k (1, 1) {mustBeFloat, mustBeNonnegative} = 12;
        % NETWORK_WATTS_STROGATZ_P  The `p` parameter for Watts–Strogatz graphs.
        network_watts_strogatz_p (1, 1) {mustBeFloat, mustBeInRange(network_watts_strogatz_p, 0, 1)} = 0.5;
        % NETWORK_BARABASI_ALBERT_M  The `m` parameter for Barabási–Albert graphs.
        network_barabasi_albert_m (1, 1) {mustBeInteger, mustBePositive} = 12;
        % NETWORK_MIN_GIRTH  The girth to which the graph should be stretched (at least)
        % in post-processing.
        %
        % Set this number to -1 if all cycles should be removed, or to 0 to not remove
        % any cycles.
        network_min_girth (1, 1) {mustBeInteger} = 0;
        % NETWORK_REQUIRE_CONNECTED  `true` if and only if the created network must
        % consist of a single connected component.
        network_require_connected (1, 1) logical = true;
        % NETWORK_MAX_ATTEMPTS  The maximum number of attempts to generate a network
        % before erroring.
        network_max_attempts (1, 1) {mustBeInteger, mustBePositive} = 250;

        % TRAIN_EPOCHS_PER_ROUND  The number of epochs to train the selected node in a
        % round of federated learning.
        train_epochs_per_round (1, 1) {mustBeInteger, mustBePositive} = 1;
        % TRAIN_MINI_BATCH_SIZE  Number of samples per mini-batch during training.
        %
        % This value should be adjusted based on the number of nodes, samples, and
        % i.i.d.-ness.
        train_mini_batch_size (1, 1) {mustBeInteger, mustBePositive} = 16;
        % TRAIN_INITIAL_LEARN_RATE  The initial learning rate of each epoch.
        train_initial_learn_rate (1, 1) {mustBeFloat, mustBeNonnegative} = 0.01;
        % TRAIN_DECAY  The learning rate decay per iteration.
        train_decay (1, 1) {mustBeFloat, mustBeNonnegative} = 0.01;
        % TRAIN_MOMENTUM  The learning momentum per epoch.
        %
        % A value of 0 means no momentum is used.
        train_momentum (1, 1) {mustBeFloat, mustBeNonnegative} = 0;

        % FL_SELF_WEIGHT  The weight of the node's own model when averaging with
        % neighbours.
        %
        % Ignored if [dataset_name] is `"number"`.
        fl_self_weight (1, 1) {mustBeFloat, mustBeNonnegative} = 1;
        % FL_ROUNDS  Number of federated learning rounds to run, or a negative number to
        % not limit the experiment by the number of rounds.
        fl_rounds (1, 1) {mustBeInteger} = 50;
        % FL_CONVERGENCE_THRESHOLD  Largest difference between model parameters before
        % convergence is achieved, or a negative number to not limit the experiment by
        % convergence.
        fl_convergence_threshold (1, 1) {mustBeFloat} = -1;

        % METRICS_TEXT  `true` if and only if metrics should be displayed in the
        % console.
        metrics_text (1, 1) logical = false;
        % METRICS_GUI  `true` if and only if metrics should be displayed in a separate
        % window.
        %
        % Requires MATLAB R2022b or later.
        metrics_gui (1, 1) logical = false;
        % METRICS_GUI_BY_NODE  `true` if and only if the GUI should show metrics
        % individually per node (in addition to global metrics).
        %
        % Ignored if [metrics_gui] is `false`.
        metrics_gui_by_node (1, 1) logical = true;
        % METRICS_TRAINING_LOSS  `true` if and only if training loss should be tracked.
        metrics_training_loss (1, 1) logical = true;
        % METRICS_VALIDATION_ACCURACY  `true` if and only if validation accuracy should
        % be tracked.
        %
        % This incurs extra computational cost each round.
        metrics_validation_accuracy (1, 1) logical = true;
    end


    methods
        function obj = ExperimentConfig(args)
            % EXPERIMENTCONFIG  Construct a configuration.
            %
            % Accepts any property of this class as a name-value argument (e.g.
            % `ExperimentConfig(node_count = 10)`); the arguments are passed on to the
            % [Config] constructor. Properties that are not given keep their defaults.
            arguments% (Input)
                args.?ExperimentConfig;
            end
            % arguments (Output)
            %     obj (1, 1) ExperimentConfig;
            % end

            obj = obj@Config(args);
        end


        function new_obj = set(obj, args)
            % SET  Return the configuration with the given properties changed.
            %
            % Accepts any property of this class as a name-value argument and delegates
            % to [Config.set]. Presumably returns a modified copy, leaving [obj]
            % untouched -- depends on [Config] being a value class; confirm.
            arguments% (Input)
                obj (1, 1) ExperimentConfig;
                args.?ExperimentConfig;
            end
            % arguments (Output)
            %     new_obj (1, 1) ExperimentConfig;
            % end

            new_obj = set@Config(obj, args);
        end

        function h = cache_id(obj)
            % CACHE_ID  Identifier of this configuration for caching purposes.
            %
            % The logging and display options ([log_settings], [log_rounds],
            % [metrics_text], [metrics_gui], [metrics_gui_by_node]) are set to `false`
            % before delegating to [Config.cache_id], so that configurations differing
            % only in those options yield the same identifier.
            arguments% (Input)
                obj (1, 1) ExperimentConfig;
            end
            % arguments (Output)
            %     h (1, 1) {mustBeText};
            % end

            h = cache_id@Config(obj.set(log_settings = false, ...
                                        log_rounds = false, ...
                                        metrics_text = false, ...
                                        metrics_gui = false, ...
                                        metrics_gui_by_node = false));
        end
    end

    methods (Static)
        function partial_obj = partial(args)
            % PARTIAL  Collect a subset of options without constructing a config.
            %
            % Accepts any property of this class as a name-value argument and delegates
            % to [Config.partial]. Per the (commented-out) output spec, the result is
            % expected to be a `(1, :)` cell -- presumably name-value pairs; confirm
            % against [Config.partial].
            arguments% (Input)
                args.?ExperimentConfig;
            end
            % arguments (Output)
            %     partial_obj (1, :) cell;
            % end

            partial_obj = partial@Config(args);
        end
    end
end
