package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.Normal;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionBase;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.ExponetialDecay;
import jsat.parameters.IntParameter;
import jsat.parameters.ObjectParameter;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * A feed-forward neural network with one or more fully connected hidden layers, trained by
 * mini-batch back-propagation with optional momentum and weight decay. The same network can be used
 * as a {@link Classifier} (one output neuron per class) or as a {@link Regressor} (a single output
 * neuron whose output range is linearly rescaled to the range of the training targets).
 * <p>
 * Only the numerical features of each {@link DataPoint} are used as network inputs. Every layer,
 * including the output layer, uses the same {@link ActivationFunction}. Training targets are placed a
 * fixed margin (0.1 by default) inside the activation's [{@code min()}, {@code max()}] range, so the
 * network never has to drive a saturating activation all the way to its asymptote.
 */
public class BackPropagationNet implements Classifier, Regressor, Parameterized {
    private static final long serialVersionUID = 335438198218313862L;

    /** Number of numerical input features, fixed when training starts. */
    private int inputSize;
    /** Number of output neurons: the number of classes for classification, 1 for regression. */
    private int outputSize;
    /** Activation function applied by every layer. */
    private ActivationFunction f;
    /** Schedule mapping (epoch, total epochs, initial rate) to the learning rate of that epoch. */
    private DecayRate learningRateDecay;
    /** Momentum coefficient; 0 disables momentum. */
    private double momentum;
    /** Weight-decay coefficient in [0, 1); 0 disables weight decay. */
    private double weightDecay;
    /** Number of passes over the training data. */
    private int epochs;
    /** Learning rate handed to {@link #learningRateDecay}. */
    private double initialLearningRate;
    /** Strategy used to draw the initial weights and biases. */
    private WeightInitialization weightInitialization;
    /** Margin by which training targets are pulled inside the activation's [min, max] range. */
    private double targetBump;
    /** Number of data points per mini-batch. */
    private int batchSize;
    /** Number of neurons in each hidden layer, first to last. */
    private int[] npl;
    /** Weight matrix of each layer (hidden layers first, output layer last); null until trained. */
    private List<Matrix> Ws;
    /** Bias vector of each layer, parallel to {@code Ws}; null until trained. */
    private List<Vec> bs;
    /** Largest regression target seen during training. */
    private double targetMax;
    /** Smallest regression target seen during training. */
    private double targetMin;
    /** Scale mapping a regression target offset {@code (t - targetMin)} into activation units. */
    private double targetMultiplier;

    /** Logistic sigmoid activation {@code 1 / (1 + e^-x)}, with output range (0, 1). */
    public static final ActivationFunction logitActiv = new ActivationFunction() {
        private static final long serialVersionUID = -5675881412853268432L;

        @Override
        public double response(double x) {
            return 1.0 / (1.0 + Math.exp(-x));
        }

        @Override
        public double min() {
            return 0.0;
        }

        @Override
        public double max() {
            return 1.0;
        }

        @Override
        public Function getD() {
            return BackPropagationNet.logitPrime;
        }

        @Override
        public String toString() {
            return "Logit";
        }
    };

    /** Derivative of the logistic sigmoid, written in terms of its output {@code y}: {@code y(1 - y)}. */
    private static final Function logitPrime = new FunctionBase() {
        private static final long serialVersionUID = 7201403465671204173L;

        @Override
        public double f(Vec x) {
            double y = x.get(0);
            return y * (1.0 - y);
        }
    };

    /** Hyperbolic tangent activation, with output range (-1, 1). */
    public static final ActivationFunction tanhActiv = new ActivationFunction() {
        private static final long serialVersionUID = 5531922338473526216L;

        @Override
        public double response(double x) {
            return Math.tanh(x);
        }

        @Override
        public double min() {
            return -1.0;
        }

        @Override
        public double max() {
            return 1.0;
        }

        @Override
        public Function getD() {
            return BackPropagationNet.tanhPrime;
        }

        @Override
        public String toString() {
            return "Tanh";
        }
    };

    /** Derivative of tanh, written in terms of its output {@code y}: {@code 1 - y^2}. */
    private static final Function tanhPrime = new FunctionBase() {
        private static final long serialVersionUID = -7271551720122166947L;

        @Override
        public double f(Vec x) {
            double y = x.get(0);
            return 1.0 - (y * y);
        }
    };

    /** Softsign activation {@code x / (1 + |x|)}, with output range (-1, 1). */
    public static final ActivationFunction softsignActiv = new ActivationFunction() {
        private static final long serialVersionUID = 1618447580574194519L;

        @Override
        public double response(double x) {
            return x / (1.0 + Math.abs(x));
        }

        @Override
        public double min() {
            return -1.0;
        }

        @Override
        public double max() {
            return 1.0;
        }

        @Override
        public Function getD() {
            return BackPropagationNet.softsignPrime;
        }

        @Override
        public String toString() {
            return "Softsign";
        }
    };

    /**
     * Derivative of softsign, written in terms of its output {@code y}: {@code (1 - |y|)^2}. This
     * equals {@code 1 / (1 + |x|)^2} because {@code 1 - |y| = 1 / (1 + |x|)}.
     */
    private static final Function softsignPrime = new FunctionBase() {
        private static final long serialVersionUID = -6726314880590071199L;

        @Override
        public double f(Vec x) {
            double oneMinusAbs = 1.0 - Math.abs(x.get(0));
            return oneMinusAbs * oneMinusAbs;
        }
    };

    /**
     * A neuron activation function, together with its output range and its derivative.
     * <p>
     * {@link #getD()} must be expressed as a function of the activation's <em>output</em> value, not
     * its input, because the network evaluates it on already-activated neuron outputs.
     */
    public static abstract class ActivationFunction implements Function {
        private static final long serialVersionUID = 8002040194215453918L;

        /**
         * @param x the net input of a neuron
         * @return the neuron's activation
         */
        public abstract double response(double x);

        /** @return the lower end of the output range of {@link #response(double)} */
        public abstract double min();

        /** @return the upper end of the output range of {@link #response(double)} */
        public abstract double max();

        /** @return the derivative, as a function of the activation's output value */
        public abstract Function getD();

        @Override
        public double f(double... x) {
            return response(x[0]);
        }

        @Override
        public double f(Vec x) {
            return response(x.get(0));
        }
    }

    /**
     * Strategies for drawing initial weights and biases. {@link #getWeight} receives:
     * <ul>
     * <li>the fan-in (the number of columns of the weight matrix);</li>
     * <li>the fan-out (the number of rows);</li>
     * <li>the initial learning rate, which none of the strategies below use;</li>
     * <li>the random source.</li>
     * </ul>
     */
    public enum WeightInitialization {
        /** Uniform on [-0.7, 0.7). */
        UNIFORM {
            @Override
            public double getWeight(int fanIn, int fanOut, double learningRate, Random rand) {
                return rand.nextDouble() * 1.4 - 0.7;
            }
        },
        /** Normal with mean 0 and scale parameter {@code 1/sqrt(fanIn)} (name spelling kept for compatibility). */
        GUASSIAN {
            @Override
            public double getWeight(int fanIn, int fanOut, double learningRate, Random rand) {
                // NOTE(review): nextDouble() can return exactly 0.0, where the inverse CDF is
                // presumably -Infinity; this happens with probability about 2^-53 per draw.
                return Normal.invcdf(rand.nextDouble(), 0.0, Math.pow(fanIn, -0.5));
            }
        },
        /**
         * Uniform on [-s, s) with {@code s = sqrt(6 / (fanIn + fanOut))}, the normalized
         * initialization of Glorot and Bengio (2010).
         */
        TANH_NORMALIZED_INITIALIZATION {
            @Override
            public double getWeight(int fanIn, int fanOut, double learningRate, Random rand) {
                double s = Math.sqrt(6.0 / (fanIn + fanOut));
                return rand.nextDouble() * s * 2.0 - s;
            }
        };

        /**
         * Draws one initial weight.
         *
         * @param fanIn number of inputs feeding the layer
         * @param fanOut number of neurons in the layer
         * @param learningRate the network's initial learning rate
         * @param rand source of randomness
         * @return the initial weight
         */
        public abstract double getWeight(int fanIn, int fanOut, double learningRate, Random rand);
    }

    /** Creates a network with a single hidden layer of 1024 neurons. */
    public BackPropagationNet() {
        this(1024);
    }

    /**
     * Creates a network with the given hidden layers and default hyper-parameters:
     * <ul>
     * <li>softsign activation and exponential learning-rate decay;</li>
     * <li>momentum 0.1, no weight decay;</li>
     * <li>1000 epochs, initial learning rate 0.2;</li>
     * <li>normalized (tanh) initialization, target margin 0.1, batch size 10.</li>
     * </ul>
     *
     * @param npl the number of neurons in each hidden layer, first to last; the array is copied
     * @throws IllegalArgumentException if no hidden layer is given or any layer has fewer than one
     *         neuron
     */
    public BackPropagationNet(int... npl) {
        if (npl.length < 1) {
            throw new IllegalArgumentException("There must be at least one hidden layer");
        }
        for (int neurons : npl) {
            if (neurons < 1) {
                throw new IllegalArgumentException(
                        "There must be a positive number of hidden neurons in each layer, not " + neurons);
            }
        }
        this.f = softsignActiv;
        this.learningRateDecay = new ExponetialDecay();
        this.momentum = 0.1;
        this.weightDecay = 0.0;
        this.epochs = 1000;
        this.initialLearningRate = 0.2;
        this.weightInitialization = WeightInitialization.TANH_NORMALIZED_INITIALIZATION;
        this.targetBump = 0.1;
        this.batchSize = 10;
        // Defensive copy: the hidden-layer parameters from getParameters() modify npl in place,
        // and those changes must not write through to the caller's array.
        this.npl = Arrays.copyOf(npl, npl.length);
    }

    /**
     * Copy constructor. Hyper-parameters are copied. Trained weight matrices and bias vectors, if
     * any, are cloned. The activation function and decay-rate objects are shared with
     * {@code toCopy}.
     *
     * @param toCopy the network to copy
     */
    protected BackPropagationNet(BackPropagationNet toCopy) {
        this(toCopy.npl);
        this.inputSize = toCopy.inputSize;
        this.outputSize = toCopy.outputSize;
        this.f = toCopy.f;
        this.momentum = toCopy.momentum;
        this.weightDecay = toCopy.weightDecay;
        this.epochs = toCopy.epochs;
        this.initialLearningRate = toCopy.initialLearningRate;
        this.learningRateDecay = toCopy.learningRateDecay;
        this.weightInitialization = toCopy.weightInitialization;
        this.targetBump = toCopy.targetBump;
        this.targetMax = toCopy.targetMax;
        this.targetMin = toCopy.targetMin;
        this.targetMultiplier = toCopy.targetMultiplier;
        this.batchSize = toCopy.batchSize;
        if (toCopy.Ws != null) {
            this.Ws = new ArrayList<Matrix>(toCopy.Ws.size());
            for (Matrix w : toCopy.Ws) {
                this.Ws.add(w.mo161clone());
            }
        }
        if (toCopy.bs != null) {
            this.bs = new ArrayList<Vec>(toCopy.bs.size());
            for (Vec b : toCopy.bs) {
                this.bs.add(b.mo45clone());
            }
        }
    }

    /**
     * Trains the network, whose weights must already exist (see {@link #setUp(Random)}), with
     * mini-batch back-propagation for {@code epochs} epochs.
     * <p>
     * Each epoch shuffles the sample order and walks it in consecutive batches of
     * {@code batchSize} points. A trailing partial batch is skipped, so a data set smaller than the
     * batch size is not trained on at all.
     * <p>
     * Within a batch, every point is first forward-propagated. Then all deltas are
     * back-propagated, still using the weights the forward pass used. Only after that are the
     * per-point gradient steps applied.
     *
     * @param dataSet a {@link ClassificationDataSet} or a {@link RegressionDataSet}
     */
    private void trainNN(DataSet dataSet) {
        final int layers = Ws.size();
        final int outputLayer = npl.length; // index of the output layer in Ws / bs

        // Per batch slot, one buffer per layer for: activated outputs, activation derivatives at
        // those outputs, and back-propagated error terms (deltas).
        List<List<Vec>> activations = new ArrayList<List<Vec>>(batchSize);
        List<List<Vec>> derivatives = new ArrayList<List<Vec>>(batchSize);
        List<List<Vec>> deltas = new ArrayList<List<Vec>>(batchSize);
        for (int b = 0; b < batchSize; b++) {
            List<Vec> act = new ArrayList<Vec>(layers);
            List<Vec> der = new ArrayList<Vec>(layers);
            List<Vec> del = new ArrayList<Vec>(layers);
            for (Matrix w : Ws) {
                act.add(new DenseVector(w.rows()));
                der.add(new DenseVector(w.rows()));
                del.add(new DenseVector(w.rows()));
            }
            activations.add(act);
            derivatives.add(der);
            deltas.add(del);
        }
        // Momentum ("velocity") matrix per layer; it persists across batches and epochs.
        List<Matrix> velocities = new ArrayList<Matrix>(layers);
        for (Matrix w : Ws) {
            velocities.add(new DenseMatrix(w.rows(), w.cols()));
        }
        // Raw inputs of the current batch: layer 0's input is not kept in the activation buffers.
        List<Vec> batchInputs = new ArrayList<Vec>(batchSize);

        final int n = dataSet.getSampleSize();
        IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);
        final double batchScale = 1.0 / batchSize;

        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(order);
            final double eta = learningRateDecay.rate(epoch, epochs, initialLearningRate);
            final double step = -eta * batchScale;

            // Only complete batches are used; the loop ends once fewer than batchSize points remain.
            for (int start = 0; n - start >= batchSize; start += batchSize) {
                batchInputs.clear();

                // 1) Forward pass and output-layer deltas for every point in the batch.
                for (int b = 0; b < batchSize; b++) {
                    int idx = order.get(start + b);
                    Vec x = dataSet.getDataPoint(idx).getNumericalValues();
                    batchInputs.add(x);
                    feedForward(x, activations.get(b), derivatives.get(b));
                    computeOutputDelta(dataSet, idx, deltas.get(b).get(outputLayer),
                            activations.get(b).get(outputLayer), derivatives.get(b).get(outputLayer));
                }

                // 2) Back-propagate all deltas BEFORE changing any weight, so each gradient uses
                //    the same weights as the forward pass that produced its activations.
                for (int b = 0; b < batchSize; b++) {
                    List<Vec> delta = deltas.get(b);
                    List<Vec> der = derivatives.get(b);
                    for (int l = layers - 2; l >= 0; l--) {
                        Vec deltaL = delta.get(l);
                        deltaL.zeroOut();
                        Ws.get(l + 1).transposeMultiply(1.0, delta.get(l + 1), deltaL);
                        deltaL.mutablePairwiseMultiply(der.get(l));
                    }
                }

                // 3) Apply each point's gradient step to every layer.
                for (int b = 0; b < batchSize; b++) {
                    for (int l = 0; l < layers; l++) {
                        Vec layerInput = (l == 0) ? batchInputs.get(b) : activations.get(b).get(l - 1);
                        updateLayer(l, deltas.get(b).get(l), layerInput, eta, step, velocities.get(l));
                    }
                }
            }
        }
    }

    /**
     * Applies one gradient step to layer {@code l}: optional weight decay, then a weight update
     * (through the momentum velocity when momentum is non-zero), then the bias update.
     *
     * @param l layer index
     * @param delta the layer's error terms for one data point
     * @param input the input the layer received for that data point
     * @param eta current learning rate
     * @param step signed, batch-scaled step size {@code -eta / batchSize}
     * @param velocity the layer's momentum matrix, updated in place
     */
    private void updateLayer(int l, Vec delta, Vec input, double eta, double step, Matrix velocity) {
        Matrix w = Ws.get(l);
        if (weightDecay != 0.0) {
            w.mutableSubtract(eta * weightDecay, w); // shrink: W <- W - eta * weightDecay * W
        }
        if (momentum != 0.0) {
            velocity.mutableMultiply(momentum);
            Matrix.OuterProductUpdate(velocity, delta, input, step);
            w.mutableAdd(velocity);
        } else {
            Matrix.OuterProductUpdate(w, delta, input, step);
        }
        bs.get(l).mutableAdd(step, delta);
    }

    /**
     * @param momentum the momentum coefficient; 0 disables momentum
     * @throws ArithmeticException if {@code momentum} is negative, NaN or infinite
     */
    public void setMomentum(double momentum) {
        if (momentum < 0.0 || Double.isNaN(momentum) || Double.isInfinite(momentum)) {
            throw new ArithmeticException("Momentum must be non negative, not " + momentum);
        }
        this.momentum = momentum;
    }

    /** @return the momentum coefficient */
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * @param rate the learning rate given to the decay schedule
     * @throws ArithmeticException if {@code rate} is not positive, or is NaN or infinite
     */
    public void setInitialLearningRate(double rate) {
        if (rate <= 0.0 || Double.isNaN(rate) || Double.isInfinite(rate)) {
            throw new ArithmeticException("Learning rate must be a positive constant, not " + rate);
        }
        this.initialLearningRate = rate;
    }

    /** @return the initial learning rate */
    public double getInitialLearningRate() {
        return this.initialLearningRate;
    }

    /** @param decayRate schedule used to compute each epoch's learning rate */
    public void setLearningRateDecay(DecayRate decayRate) {
        this.learningRateDecay = decayRate;
    }

    /** @return the learning-rate decay schedule */
    public DecayRate getLearningRateDecay() {
        return this.learningRateDecay;
    }

    /**
     * @param epochs the number of training passes over the data
     * @throws ArithmeticException if {@code epochs < 1}
     */
    public void setEpochs(int epochs) {
        if (epochs < 1) {
            throw new ArithmeticException("number of training epochs must be positive, not " + epochs);
        }
        this.epochs = epochs;
    }

    /** @return the number of training epochs */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * @param decay the weight-decay coefficient; 0 disables weight decay
     * @throws ArithmeticException if {@code decay} is outside [0, 1) or is NaN
     */
    public void setWeightDecay(double decay) {
        if (decay < 0.0 || decay >= 1.0 || Double.isNaN(decay)) {
            throw new ArithmeticException("Weight decay must be in [0,1), not " + decay);
        }
        this.weightDecay = decay;
    }

    /** @return the weight-decay coefficient */
    public double getWeightDecay() {
        return this.weightDecay;
    }

    /** @param weightInitialization the strategy for drawing initial weights and biases */
    public void setWeightInitialization(WeightInitialization weightInitialization) {
        this.weightInitialization = weightInitialization;
    }

    /** @return the weight initialization strategy */
    public WeightInitialization getWeightInitialization() {
        return this.weightInitialization;
    }

    /**
     * @param batchSize the number of data points per mini-batch
     * @throws ArithmeticException if {@code batchSize < 1}; a zero batch size would make the
     *         per-batch scale infinite, and the training loop could never advance
     */
    public void setBatchSize(int batchSize) {
        if (batchSize < 1) {
            throw new ArithmeticException("Batch size must be positive, not " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /** @return the mini-batch size */
    public int getBatchSize() {
        return this.batchSize;
    }

    /** @param activationFunction the activation used by every layer */
    public void setActivationFunction(ActivationFunction activationFunction) {
        this.f = activationFunction;
    }

    /** @return the activation used by every layer */
    public ActivationFunction getActivationFunction() {
        return this.f;
    }

    /**
     * Classifies a point. Each output is first shifted so the "off" training target
     * ({@code f.min() + targetBump}) maps to 0. Negative values are then clamped to 0, and the
     * results are normalized.
     * <p>
     * NOTE(review): if every output is at or below the "off" target, all probabilities are 0
     * before {@code normalize()}. Confirm how CategoricalResults handles an all-zero vector.
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults results = new CategoricalResults(this.outputSize);
        Vec out = feedForward(dataPoint.getNumericalValues());
        out.mutableSubtract(this.f.min() + this.targetBump);
        for (int i = 0; i < out.length(); i++) {
            results.setProb(i, Math.max(out.get(i), 0.0));
        }
        results.normalize();
        return results;
    }

    /**
     * Predicts a regression value by inverting the target scaling used in training:
     * {@code (out - f.min() - targetBump) / targetMultiplier + targetMin}.
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        double out = feedForward(dataPoint.getNumericalValues()).get(0);
        return (((out - this.f.min()) - this.targetBump) / this.targetMultiplier) + this.targetMin;
    }

    /** Trains on a single thread; {@code executorService} is ignored. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet, ExecutorService executorService) {
        trainC(dataSet);
    }

    /** Creates fresh random weights and trains one output neuron per class. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet) {
        this.inputSize = dataSet.getNumNumericalVars();
        this.outputSize = dataSet.getClassSize();
        setUp(new Random());
        trainNN(dataSet);
    }

    /** Trains on a single thread; {@code executorService} is ignored. */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet, ExecutorService executorService) {
        train(dataSet);
    }

    /**
     * Creates fresh random weights and trains a single output neuron. Targets are mapped linearly
     * from [targetMin, targetMax] onto
     * [{@code f.min() + targetBump}, {@code f.max() - targetBump}].
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet) {
        this.targetMax = Double.NEGATIVE_INFINITY;
        this.targetMin = Double.POSITIVE_INFINITY;
        for (int i = 0; i < dataSet.getSampleSize(); i++) {
            double target = dataSet.getTargetValue(i);
            this.targetMax = Math.max(this.targetMax, target);
            this.targetMin = Math.min(this.targetMin, target);
        }
        double outputSpan = (this.f.max() - this.targetBump) - (this.f.min() + this.targetBump);
        double targetRange = this.targetMax - this.targetMin;
        // With a constant target the range is 0, which would make the multiplier infinite and
        // every scaled target Inf * 0 = NaN. Any finite multiplier works then, because
        // (t - targetMin) is always 0.
        this.targetMultiplier = targetRange > 0.0 ? outputSpan / targetRange : 1.0;
        this.inputSize = dataSet.getNumNumericalVars();
        this.outputSize = 1;
        setUp(new Random());
        trainNN(dataSet);
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a copy of this network with its own copies of any trained weights and biases (see
     * the copy constructor). The method name comes from decompilation; it is the {@code clone()} of
     * the implemented interfaces.
     */
    @Override // jsat.regression.Regressor
    public BackPropagationNet mo234clone() {
        return new BackPropagationNet(this);
    }

    /**
     * Allocates and randomly initializes the weight matrix and bias vector of every layer: one per
     * hidden layer, then the output layer. Layer {@code l} maps a vector of its fan-in size to one
     * of its fan-out size.
     *
     * @param rand source of randomness for the initial values
     */
    private void setUp(Random rand) {
        this.Ws = new ArrayList<Matrix>(this.npl.length + 1);
        this.bs = new ArrayList<Vec>(this.npl.length + 1);
        int fanIn = this.inputSize;
        for (int l = 0; l <= this.npl.length; l++) {
            int fanOut = (l < this.npl.length) ? this.npl[l] : this.outputSize;
            DenseMatrix w = new DenseMatrix(fanOut, fanIn);
            DenseVector b = new DenseVector(w.rows());
            initializeWeights(w, rand);
            initializeWeights(b, w.cols(), rand);
            this.Ws.add(w);
            this.bs.add(b);
            fanIn = fanOut;
        }
    }

    /**
     * Writes the output-layer error terms for one data point into {@code delta}:
     * {@code (output - target) * f'(output)}.
     *
     * @param dataSet the training set, which supplies the target
     * @param idx index of the data point in {@code dataSet}
     * @param delta output-layer delta buffer, overwritten
     * @param output the output layer's activations for the point
     * @param derivative the activation derivatives at {@code output}
     * @return the squared error between target and output
     * @throws RuntimeException if {@code dataSet} is neither a classification nor a regression set
     */
    private double computeOutputDelta(DataSet dataSet, int idx, Vec delta, Vec output, Vec derivative) {
        double squaredError = 0.0;
        if (dataSet instanceof ClassificationDataSet) {
            int trueClass = ((ClassificationDataSet) dataSet).getDataPointCategory(idx);
            for (int i = 0; i < delta.length(); i++) {
                // The "on" neuron targets f.max() - targetBump; all others target f.min() + targetBump.
                double target = (i == trueClass) ? this.f.max() - this.targetBump : this.f.min() + this.targetBump;
                double err = target - output.get(i);
                squaredError += err * err;
                delta.set(i, -err * derivative.get(i));
            }
        } else if (dataSet instanceof RegressionDataSet) {
            double target = this.f.min() + this.targetBump
                    + (this.targetMultiplier * (((RegressionDataSet) dataSet).getTargetValue(idx) - this.targetMin));
            double err = target - output.get(0);
            squaredError = err * err;
            delta.set(0, -err * derivative.get(0));
        } else {
            throw new RuntimeException("BUG: please report");
        }
        return squaredError;
    }

    /**
     * Forward pass that records, for every layer, the activated output in {@code activations} and
     * the activation derivative at that output in {@code derivatives}.
     *
     * @param x the input vector
     * @param activations pre-allocated per-layer output buffers, overwritten
     * @param derivatives pre-allocated per-layer derivative buffers, overwritten
     */
    private void feedForward(Vec x, List<Vec> activations, List<Vec> derivatives) {
        Vec layerInput = x;
        for (int l = 0; l < this.Ws.size(); l++) {
            Vec out = activations.get(l);
            out.zeroOut();
            // W_l * input is presumably accumulated into out, which is why out was just zeroed.
            this.Ws.get(l).multiply(layerInput, 1.0, out);
            out.mutableAdd(this.bs.get(l));
            out.applyFunction(this.f);
            Vec deriv = derivatives.get(l);
            out.copyTo(deriv);
            deriv.applyFunction(this.f.getD()); // derivative is a function of the output value
            layerInput = out;
        }
    }

    /**
     * Forward pass that allocates new vectors and returns only the output layer's activations.
     *
     * @param x the input vector
     * @return the network's output
     */
    private Vec feedForward(Vec x) {
        Vec out = x;
        for (int l = 0; l < this.Ws.size(); l++) {
            out = this.Ws.get(l).multiply(out);
            out.mutableAdd(this.bs.get(l));
            out.applyFunction(this.f);
        }
        return out;
    }

    /** Fills every entry of {@code w} from the weight-initialization strategy. */
    private void initializeWeights(Matrix w, Random rand) {
        final int fanIn = w.cols();
        final int fanOut = w.rows();
        for (int i = 0; i < fanOut; i++) {
            for (int j = 0; j < fanIn; j++) {
                w.set(i, j, this.weightInitialization.getWeight(fanIn, fanOut, this.initialLearningRate, rand));
            }
        }
    }

    /** Fills every entry of bias vector {@code b} from the strategy, using the layer's fan-in. */
    private void initializeWeights(Vec b, int fanIn, Random rand) {
        for (int i = 0; i < b.length(); i++) {
            b.set(i, this.weightInitialization.getWeight(fanIn, b.length(), this.initialLearningRate, rand));
        }
    }

    /**
     * Returns, in an unmodifiable list:
     * <ul>
     * <li>the parameters found by {@code Parameter.getParamsFromMethods(this)};</li>
     * <li>one integer parameter per hidden-layer size;</li>
     * <li>an object parameter choosing among the three built-in activation functions.</li>
     * </ul>
     *
     * @throws ArithmeticException if a hidden layer has fewer than one neuron
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        List<Parameter> params = new ArrayList<Parameter>(Parameter.getParamsFromMethods(this));
        for (int i = 0; i < this.npl.length; i++) {
            final int layer = i;
            if (this.npl[layer] < 1) {
                throw new ArithmeticException("There must be a positive number of hidden neurons in each layer");
            }
            params.add(new IntParameter() {
                private static final long serialVersionUID = -827784019950722754L;

                @Override
                public int getValue() {
                    return BackPropagationNet.this.npl[layer];
                }

                @Override
                public boolean setValue(int neurons) {
                    if (neurons <= 0) {
                        return false;
                    }
                    BackPropagationNet.this.npl[layer] = neurons;
                    return true;
                }

                @Override
                public String getASCIIName() {
                    return "Neurons for Hidden Layer " + layer;
                }
            });
        }
        params.add(new ObjectParameter<ActivationFunction>() {
            private static final long serialVersionUID = 6871130865935243583L;

            @Override
            public ActivationFunction getObject() {
                return BackPropagationNet.this.getActivationFunction();
            }

            @Override
            public boolean setObject(ActivationFunction activationFunction) {
                BackPropagationNet.this.setActivationFunction(activationFunction);
                return true;
            }

            @Override
            public List<ActivationFunction> parameterOptions() {
                return Arrays.asList(BackPropagationNet.logitActiv, BackPropagationNet.tanhActiv,
                        BackPropagationNet.softsignActiv);
            }

            @Override
            public String getASCIIName() {
                return "Activation Function";
            }
        });
        return Collections.unmodifiableList(params);
    }

    /** @return the parameter named {@code paramName}, looked up among {@link #getParameters()} */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(getParameters()).get(paramName);
    }
}
