package jsat.classifiers.neuralnetwork;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.neuralnetwork.activations.ActivationLayer;
import jsat.classifiers.neuralnetwork.initializers.BiastInitializer;
import jsat.classifiers.neuralnetwork.initializers.WeightInitializer;
import jsat.classifiers.neuralnetwork.regularizers.Max2NormRegularizer;
import jsat.classifiers.neuralnetwork.regularizers.WeightRegularizer;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.NoDecay;
import jsat.math.optimization.stochastic.GradientUpdater;
import jsat.math.optimization.stochastic.SimpleSGD;
import jsat.utils.SystemInfo;
import jsat.utils.random.XOR96;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:jsat/classifiers/neuralnetwork/SGDNetworkTrainer.class */
/**
 * Trains a fully connected feed-forward neural network with mini-batch stochastic gradient descent
 * and dropout.
 *
 * <p>For every non-input layer {@code i} (1-based in {@code layerSizes}) the trainer holds a weight
 * matrix of size {@code layerSizes[i] x layerSizes[i-1]}, a bias vector of length
 * {@code layerSizes[i]}, and the activation at index {@code i-1} of the activation-layer list. Each
 * weight row and each bias vector gets its own clone of the configured {@link GradientUpdater}.
 *
 * <p>Typical lifecycle:
 * <ol>
 * <li>configure layer sizes, activations, initializers, eta, dropout, regularizer and updater;</li>
 * <li>{@link #setup()} to allocate and initialize the parameters;</li>
 * <li>repeated {@link #updateMiniBatch(List, List, ExecutorService)} calls;</li>
 * <li>{@link #finishUpdating()} to release training buffers and rescale weights for dropout;</li>
 * <li>{@link #feedfoward(Vec)} to evaluate the trained network.</li>
 * </ol>
 *
 * <p>The class has no internal synchronization. An optional {@link ExecutorService} is only used to
 * parallelize work inside a single call.
 */
public class SGDNetworkTrainer implements Serializable {
    private static final long serialVersionUID = 5753653181230693131L;

    /** Number of neurons per layer; index 0 is the input layer. */
    private int[] layerSizes;
    /** Base learning rate; passed through {@link #etaDecay} on every mini-batch. */
    private double eta;
    /** Probability of dropping an input feature during training. */
    private double p_i;
    /** {@link #p_i} mapped onto the {@code int} range; see {@link #dropoutThreshold(double)}. */
    private int p_i_intThresh;
    /** Probability of dropping a hidden neuron during training (never applied to the output layer). */
    private double p_o;
    /** {@link #p_o} mapped onto the {@code int} range; see {@link #dropoutThreshold(double)}. */
    private int p_o_intThresh;
    /** Prototype updater, cloned once per weight row and once per bias vector. */
    private GradientUpdater updater;
    private WeightRegularizer regularizer;
    private WeightInitializer weightInit;
    private BiastInitializer biasInit;
    /** Weight matrix of each non-input layer. */
    private List<Matrix> W;
    /** Weight gradients of the current mini-batch; null when not prepared for updating. */
    private List<Matrix> W_deltas;
    /** One updater per row of each weight matrix; null when not prepared for updating. */
    private List<List<GradientUpdater>> W_updaters;
    /** Bias vector of each non-input layer. */
    private List<Vec> B;
    /** Bias gradients of the current mini-batch; null when not prepared for updating. */
    private List<Vec> B_deltas;
    /** One updater per bias vector; null when not prepared for updating. */
    private List<GradientUpdater> B_updaters;
    private List<ActivationLayer> layersActivation;
    private DecayRate etaDecay;
    /** Number of mini-batch updates performed so far; fed to {@link #etaDecay}. */
    private int time;
    /** Per-layer scratch buffers with one column per mini-batch example; null outside training. */
    private Matrix[] activations;
    private Matrix[] unactivated;
    private Matrix[] deltas;

    /**
     * Creates a trainer with these defaults:
     * <ul>
     * <li>a {@link SimpleSGD} updater;</li>
     * <li>a {@link Max2NormRegularizer} with bound 15;</li>
     * <li>{@link NoDecay};</li>
     * <li>20% input dropout and 50% hidden dropout.</li>
     * </ul>
     * Layer sizes, activations and the weight and bias initializers must still be set before
     * {@link #setup()}. Eta is 0 until {@link #setEta(double)} is called.
     */
    public SGDNetworkTrainer() {
        this.updater = new SimpleSGD();
        this.regularizer = new Max2NormRegularizer(15.0d);
        this.etaDecay = new NoDecay();
        this.p_i = checkDropoutProbability(0.2d);
        this.p_i_intThresh = dropoutThreshold(this.p_i);
        this.p_o = checkDropoutProbability(0.5d);
        this.p_o_intThresh = dropoutThreshold(this.p_o);
    }

    /**
     * Deep-copy constructor. Copies the configuration, the parameters, the update counter and,
     * when {@code toCopy} is mid-training, its gradients and updaters. The copy can then continue
     * training independently. Null components of {@code toCopy} stay null in the copy.
     *
     * @param toCopy the trainer to copy
     */
    public SGDNetworkTrainer(SGDNetworkTrainer toCopy) {
        this.updater = toCopy.updater == null ? null : toCopy.updater.m233clone();
        this.regularizer = toCopy.regularizer == null ? null : toCopy.regularizer.m89clone();
        // NOTE(review): the eta decay schedule is not carried over; like the original
        // implementation, the copy always starts with NoDecay. Confirm whether this is intended.
        this.etaDecay = new NoDecay();
        this.layerSizes = toCopy.layerSizes == null
                ? null
                : Arrays.copyOf(toCopy.layerSizes, toCopy.layerSizes.length);
        this.eta = toCopy.eta;
        this.time = toCopy.time;
        this.weightInit = toCopy.weightInit == null ? null : toCopy.weightInit.m88clone();
        this.biasInit = toCopy.biasInit == null ? null : toCopy.biasInit.m88clone();
        this.p_i = toCopy.p_i;
        this.p_i_intThresh = toCopy.p_i_intThresh;
        this.p_o = toCopy.p_o;
        this.p_o_intThresh = toCopy.p_o_intThresh;

        if (toCopy.layersActivation != null) {
            this.layersActivation = new ArrayList<ActivationLayer>(toCopy.layersActivation.size());
            for (ActivationLayer layer : toCopy.layersActivation) {
                this.layersActivation.add(layer.m84clone());
            }
        }
        if (toCopy.W != null) {
            this.W = new ArrayList<Matrix>(toCopy.W.size());
            for (Matrix w : toCopy.W) {
                this.W.add(w.mo161clone());
            }
            this.B = new ArrayList<Vec>(toCopy.B.size());
            for (Vec b : toCopy.B) {
                this.B.add(b.mo45clone());
            }
        }
        if (toCopy.W_deltas != null) {
            this.W_deltas = new ArrayList<Matrix>(toCopy.W_deltas.size());
            for (Matrix dw : toCopy.W_deltas) {
                this.W_deltas.add(dw.mo161clone());
            }
            this.B_deltas = new ArrayList<Vec>(toCopy.B_deltas.size());
            for (Vec db : toCopy.B_deltas) {
                this.B_deltas.add(db.mo45clone());
            }
        }
        if (toCopy.W_updaters != null) {
            this.W_updaters = new ArrayList<List<GradientUpdater>>(toCopy.W_updaters.size());
            for (List<GradientUpdater> layerUpdaters : toCopy.W_updaters) {
                List<GradientUpdater> copies = new ArrayList<GradientUpdater>(layerUpdaters.size());
                for (GradientUpdater rowUpdater : layerUpdaters) {
                    copies.add(rowUpdater.m233clone());
                }
                this.W_updaters.add(copies);
            }
            // Start empty (capacity only): seeding the list with toCopy.B_updaters would duplicate
            // entries and share them with the source.
            this.B_updaters = new ArrayList<GradientUpdater>(toCopy.B_updaters.size());
            for (GradientUpdater biasUpdater : toCopy.B_updaters) {
                this.B_updaters.add(biasUpdater.m233clone());
            }
        }
        if (toCopy.activations != null) {
            // updateMiniBatch allocates these buffers lazily, so empty arrays of the right
            // length are enough for the copy to keep training.
            this.activations = new Matrix[toCopy.activations.length];
            this.unactivated = new Matrix[toCopy.unactivated.length];
            this.deltas = new Matrix[toCopy.deltas.length];
        }
    }

    /**
     * Sets the probability of dropping each input feature during training.
     *
     * @param d dropout probability in {@code [0, 1)}; 0 disables input dropout
     * @throws IllegalArgumentException if {@code d} is outside {@code [0, 1)} or NaN
     */
    public void setDropoutInput(double d) {
        this.p_i = checkDropoutProbability(d);
        this.p_i_intThresh = dropoutThreshold(d);
    }

    /** @return the probability of dropping each input feature during training */
    public double getDropoutInput() {
        return this.p_i;
    }

    /**
     * Sets the probability of dropping each hidden neuron during training. The output layer is
     * never dropped.
     *
     * @param d dropout probability in {@code [0, 1)}; 0 disables hidden dropout
     * @throws IllegalArgumentException if {@code d} is outside {@code [0, 1)} or NaN
     */
    public void setDropoutHidden(double d) {
        this.p_o = checkDropoutProbability(d);
        this.p_o_intThresh = dropoutThreshold(d);
    }

    /** @return the probability of dropping each hidden neuron during training */
    public double getDropoutHidden() {
        return this.p_o;
    }

    /** @param decayRate schedule applied to {@link #getEta()}, indexed by the number of updates */
    public void setEtaDecay(DecayRate decayRate) {
        this.etaDecay = decayRate;
    }

    /** @return the learning-rate decay schedule */
    public DecayRate getEtaDecay() {
        return this.etaDecay;
    }

    /**
     * Sets the base learning rate.
     *
     * @param d a finite positive learning rate
     * @throws IllegalArgumentException if {@code d} is not a finite positive number
     */
    public void setEta(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("eta must be a positive constant, not " + d);
        }
        this.eta = d;
    }

    /** @return the base learning rate */
    public double getEta() {
        return this.eta;
    }

    /** @param weightRegularizer regularizer applied to each layer's weights after every update */
    public void setRegularizer(WeightRegularizer weightRegularizer) {
        this.regularizer = weightRegularizer;
    }

    /** @return the weight regularizer */
    public WeightRegularizer getRegularizer() {
        return this.regularizer;
    }

    /**
     * Sets the number of neurons per layer, starting with the input layer. The array is stored by
     * reference, not copied.
     *
     * @param iArr layer sizes; must have one more entry than the activation-layer list
     */
    public void setLayerSizes(int... iArr) {
        this.layerSizes = iArr;
    }

    /** @return the layer sizes (the internal array, not a copy) */
    public int[] getLayerSizes() {
        return this.layerSizes;
    }

    /** @param list one activation per non-input layer; stored by reference */
    public void setLayersActivation(List<ActivationLayer> list) {
        this.layersActivation = list;
    }

    /** @param gradientUpdater prototype updater, cloned per weight row and bias vector in setup */
    public void setGradientUpdater(GradientUpdater gradientUpdater) {
        this.updater = gradientUpdater;
    }

    /** @return the prototype gradient updater */
    public GradientUpdater getGradientUpdater() {
        return this.updater;
    }

    /** @param weightInitializer initializer used for every weight matrix in {@link #setup()} */
    public void setWeightInit(WeightInitializer weightInitializer) {
        this.weightInit = weightInitializer;
    }

    /** @return the weight initializer */
    public WeightInitializer getWeightInit() {
        return this.weightInit;
    }

    /** @param biastInitializer initializer used for every bias vector in {@link #setup()} */
    public void setBiasInit(BiastInitializer biastInitializer) {
        this.biasInit = biastInitializer;
    }

    /** @return the bias initializer */
    public BiastInitializer getBiasInit() {
        return this.biasInit;
    }

    /**
     * Allocates and initializes all weights and biases, resets the update counter, and prepares
     * the gradient buffers and per-row updaters for training.
     */
    public void setup() {
        assert layersActivation.size() == layerSizes.length - 1
                : "need exactly one activation per non-input layer";
        W = new ArrayList<Matrix>(layersActivation.size());
        B = new ArrayList<Vec>(layersActivation.size());
        Random rand = new XOR96();
        for (int i = 1; i < layerSizes.length; i++) {
            Matrix weights = new DenseMatrix(layerSizes[i], layerSizes[i - 1]);
            W.add(weights);
            weightInit.init(weights, rand);
            Vec bias = new DenseVector(layerSizes[i]);
            B.add(bias);
            // the previous layer's size (the fan-in) is passed to the bias initializer
            biasInit.init(bias, layerSizes[i - 1], rand);
        }
        time = 0;
        prepareForUpdating();
    }

    /**
     * Allocates the per-layer gradient buffers, clones the prototype updater for every weight row
     * and every bias vector, and creates empty scratch-buffer arrays.
     */
    private void prepareForUpdating() {
        final int layers = layersActivation.size();
        W_deltas = new ArrayList<Matrix>(layers);
        W_updaters = new ArrayList<List<GradientUpdater>>(layers);
        B_deltas = new ArrayList<Vec>(layers);
        B_updaters = new ArrayList<GradientUpdater>(layers);
        for (int i = 1; i < layerSizes.length; i++) {
            W_deltas.add(new DenseMatrix(layerSizes[i], layerSizes[i - 1]));
            B_deltas.add(new DenseVector(layerSizes[i]));

            List<GradientUpdater> rowUpdaters = new ArrayList<GradientUpdater>(layerSizes[i]);
            for (int row = 0; row < layerSizes[i]; row++) {
                GradientUpdater rowUpdater = updater.m233clone();
                rowUpdater.setup(layerSizes[i - 1]);
                rowUpdaters.add(rowUpdater);
            }
            W_updaters.add(rowUpdaters);

            GradientUpdater biasUpdater = updater.m233clone();
            biasUpdater.setup(layerSizes[i]);
            B_updaters.add(biasUpdater);
        }
        activations = new Matrix[layers];
        unactivated = new Matrix[layers];
        deltas = new Matrix[layers];
    }

    /**
     * Ends training.
     *
     * <p>First it releases the gradient buffers, updaters and scratch matrices. It then scales the
     * first layer's weights and bias by {@code 1 - inputDropout}, and every later layer's weights
     * and bias by {@code 1 - hiddenDropout}. Each call rescales again, so call it only once.
     *
     * <p>NOTE(review): biases are scaled together with the weights; confirm this is intended.
     */
    public void finishUpdating() {
        W_deltas = null;
        W_updaters = null;
        B_deltas = null;
        B_updaters = null;
        deltas = null;
        unactivated = null;
        activations = null;
        W.get(0).mutableMultiply(1.0d - p_i);
        B.get(0).mutableMultiply(1.0d - p_i);
        for (int i = 1; i < W.size(); i++) {
            W.get(i).mutableMultiply(1.0d - p_o);
            B.get(i).mutableMultiply(1.0d - p_o);
        }
    }

    /**
     * Performs one single-threaded mini-batch update.
     *
     * @see #updateMiniBatch(List, List, ExecutorService)
     */
    public double updateMiniBatch(List<Vec> list, List<Vec> list2) {
        return updateMiniBatch(list, list2, null);
    }

    /**
     * Performs one mini-batch update. The steps are:
     * <ol>
     * <li>apply dropout;</li>
     * <li>feed the batch forward;</li>
     * <li>backpropagate {@code output - target};</li>
     * <li>average the gradients over the batch;</li>
     * <li>apply them at the decayed learning rate.</li>
     * </ol>
     *
     * @param list input vectors of the mini-batch
     * @param list2 target output vectors, parallel to {@code list}
     * @param executorService pool used to parallelize the work, or {@code null} to run serially
     * @return the sum over the batch of the Euclidean norm of {@code output - target}, measured
     *     before this update is applied
     * @throws IllegalStateException if {@link #setup()} has not been called, or
     *     {@link #finishUpdating()} has been called since
     * @throws IllegalArgumentException if the batch is empty or the two lists differ in size
     */
    public double updateMiniBatch(List<Vec> list, List<Vec> list2, ExecutorService executorService) {
        if (W_deltas == null) {
            throw new IllegalStateException(
                    "Trainer is not prepared for updating; call setup() before updateMiniBatch");
        }
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Inputs and targets must have the same size, not "
                    + list.size() + " and " + list2.size());
        }
        if (list.isEmpty()) {
            // the gradient is averaged over the batch, so an empty batch would write NaN weights
            throw new IllegalArgumentException("Mini-batch must contain at least one example");
        }
        final int batchSize = list.size();
        Random rand = new XORWOW();

        for (Matrix dw : W_deltas) {
            dw.zeroOut();
        }
        for (Vec db : B_deltas) {
            db.zeroOut();
        }
        // Scratch buffers are reused across calls and only reallocated when the batch size changes.
        for (int i = 0; i < layersActivation.size(); i++) {
            activations[i] = ensureColumns(activations[i], layerSizes[i + 1], batchSize);
            unactivated[i] = ensureColumns(unactivated[i], layerSizes[i + 1], batchSize);
            deltas[i] = ensureColumns(deltas[i], layerSizes[i + 1], batchSize);
        }

        // One column per example.
        DenseMatrix input = new DenseMatrix(layerSizes[0], batchSize);
        for (int col = 0; col < batchSize; col++) {
            list.get(col).copyTo(input.getColumnView(col));
        }
        if (p_i > 0.0d) {
            applyDropout(input, p_i_intThresh, rand, executorService);
        }

        feedforward(input, activations, unactivated, executorService, rand);
        double error = backpropagateError(deltas, activations, list, list2, 0.0d, executorService, unactivated);
        accumulateUpdates(input, activations, deltas, executorService, list);

        double rate = etaDecay.rate(time++, eta);
        if (executorService == null) {
            applyGradient(rate);
        } else {
            applyGradient(rate, executorService);
        }
        return error;
    }

    /**
     * Returns {@code buffer} when it already has {@code cols} columns; otherwise returns a new
     * {@code rows x cols} dense matrix.
     */
    private static Matrix ensureColumns(Matrix buffer, int rows, int cols) {
        if (buffer != null && buffer.cols() == cols) {
            return buffer;
        }
        return new DenseMatrix(rows, cols);
    }

    /**
     * Training-time forward pass over a whole mini-batch. For each layer it computes
     * {@code z = W x + b}, applies hidden dropout to {@code z} (except on the output layer), and
     * stores the activated values.
     *
     * @param input the (possibly dropped-out) input, one column per example
     * @param outputs receives each layer's activated values
     * @param preActivations receives each layer's values before activation
     * @param executor pool used for parallel work, or {@code null}
     * @param rand source of randomness for hidden-layer dropout
     */
    private void feedforward(Matrix input, Matrix[] outputs, Matrix[] preActivations,
            ExecutorService executor, Random rand) {
        final int layers = layersActivation.size();
        for (int i = 0; i < layers; i++) {
            Matrix layerInput = i == 0 ? input : outputs[i - 1];
            final Matrix z = preActivations[i];
            z.zeroOut();
            if (executor == null) {
                W.get(i).multiply(layerInput, z);
            } else {
                W.get(i).multiply(layerInput, z, executor);
            }

            // Add the bias of each neuron to every column (example).
            final Vec bias = B.get(i);
            forEachRow(z.rows(), executor, new StripedRowWork() {
                @Override
                public void processRow(int stripe, int row) {
                    double b = bias.get(row);
                    for (int col = 0; col < z.cols(); col++) {
                        z.increment(row, col, b);
                    }
                }
            });

            // NOTE(review): dropout zeroes the pre-activation value, so a dropped neuron outputs
            // activation(0). Whether that equals 0 depends on the activation; confirm it is intended.
            if (p_o > 0.0d && i != layers - 1) {
                applyDropout(z, p_o_intThresh, rand, executor);
            }
            layersActivation.get(i).activate(z, outputs[i], false);
        }
    }

    /**
     * Evaluates the network on one input, with no dropout. The weights are used as they
     * currently are; {@link #finishUpdating()} is what applies the dropout rescaling. Despite the
     * spelling, this is the inference-time forward pass.
     *
     * @param vec input vector of length {@code layerSizes[0]}
     * @return the activated output of the last layer
     * @throws IllegalStateException if {@link #setup()} has never been called
     */
    public Vec feedfoward(Vec vec) {
        if (W == null) {
            throw new IllegalStateException("Network has no weights; call setup() first");
        }
        Vec current = vec;
        for (int i = 0; i < layersActivation.size(); i++) {
            DenseVector next = new DenseVector(layerSizes[i + 1]);
            next.zeroOut();
            W.get(i).multiply(current, 1.0d, next);
            next.mutableAdd(B.get(i));
            layersActivation.get(i).activate(next, next);
            current = next;
        }
        return current;
    }

    /**
     * Computes every layer's error signal, working from the output layer back to the first layer.
     *
     * <p>The output layer's delta is {@code output - target}; no activation derivative is applied
     * there. NOTE(review): this presumes a matching output activation and loss pairing; confirm.
     * Each hidden layer's delta is {@code W[i+1]^T delta[i+1]}, passed back through that layer's
     * activation.
     *
     * @param deltas receives each layer's delta, one column per example
     * @param outputs each layer's activated values from {@link #feedforward}
     * @param inputs the mini-batch inputs; only their count is used
     * @param targets target vectors, parallel to {@code inputs}
     * @param error initial error value to add to
     * @param executor pool used for parallel work, or {@code null}
     * @param preActivations each layer's values before activation, from {@link #feedforward}
     * @return {@code error} plus the sum of the Euclidean norms of {@code output - target}
     */
    private double backpropagateError(Matrix[] deltas, Matrix[] outputs, List<Vec> inputs,
            List<Vec> targets, double error, ExecutorService executor, Matrix[] preActivations) {
        final int last = layersActivation.size() - 1;
        for (int i = last; i >= 0; i--) {
            Matrix delta = deltas[i];
            if (i == last) {
                outputs[i].copyTo(delta);
                for (int col = 0; col < inputs.size(); col++) {
                    // column views write through to delta
                    Vec column = delta.getColumnView(col);
                    column.mutableSubtract(targets.get(col));
                    error += column.pNorm(2.0d);
                }
            } else {
                delta.zeroOut();
                if (executor == null) {
                    W.get(i + 1).transposeMultiply(deltas[i + 1], delta);
                } else {
                    W.get(i + 1).transposeMultiply(deltas[i + 1], delta, executor);
                }
                layersActivation.get(i).backprop(preActivations[i], outputs[i], delta, delta, false);
            }
        }
        return error;
    }

    /**
     * Fills {@link #W_deltas} and {@link #B_deltas} with the batch-averaged gradients.
     * {@code dW = delta * input^T / n}, and {@code dB} is each delta row summed over the batch,
     * divided by {@code n}.
     *
     * @param input the first layer's input, one column per example
     * @param outputs each layer's activated values
     * @param deltas each layer's delta from {@link #backpropagateError}
     * @param executor pool used for parallel work, or {@code null}
     * @param batch the mini-batch; only its size {@code n} is used
     */
    private void accumulateUpdates(Matrix input, Matrix[] outputs, Matrix[] deltas,
            ExecutorService executor, List<Vec> batch) {
        final double scale = 1.0d / batch.size();
        for (int i = 0; i < layersActivation.size(); i++) {
            Matrix layerInput = i == 0 ? input : outputs[i - 1];
            final Matrix delta = deltas[i];
            Matrix dw = W_deltas.get(i);
            if (executor == null) {
                delta.multiplyTranspose(layerInput, dw);
            } else {
                delta.multiplyTranspose(layerInput, dw, executor);
            }
            dw.mutableMultiply(scale);

            final Vec db = B_deltas.get(i);
            forEachRow(delta.rows(), executor, new StripedRowWork() {
                @Override
                public void processRow(int stripe, int row) {
                    double sum = 0.0d;
                    for (int col = 0; col < delta.cols(); col++) {
                        sum += delta.get(row, col);
                    }
                    db.increment(row, sum * scale);
                }
            });
        }
    }

    /**
     * Serial parameter update. For each layer it updates the bias, then every weight row, then
     * runs the regularizer on the layer.
     *
     * @param stepSize the decayed learning rate for this mini-batch
     */
    private void applyGradient(double stepSize) {
        for (int i = 0; i < layersActivation.size(); i++) {
            B_updaters.get(i).update(B.get(i), B_deltas.get(i), stepSize);
            Matrix w = W.get(i);
            Matrix dw = W_deltas.get(i);
            List<GradientUpdater> rowUpdaters = W_updaters.get(i);
            for (int row = 0; row < w.rows(); row++) {
                rowUpdaters.get(row).update(w.getRowView(row), dw.getRowView(row), stepSize);
            }
            regularizer.applyRegularization(w, B.get(i));
        }
    }

    /**
     * Parallel parameter update. Each layer's bias is updated on the calling thread. Every weight
     * row is updated and regularized in its own task; the regularizer may replace that row's bias
     * entry. Blocks until all tasks finish, and rethrows the first task failure.
     *
     * @param stepSize the decayed learning rate for this mini-batch
     * @param executor pool that runs the per-row tasks
     */
    private void applyGradient(final double stepSize, ExecutorService executor) {
        List<Future<?>> pending = new ArrayList<Future<?>>();
        for (int i = 0; i < layersActivation.size(); i++) {
            B_updaters.get(i).update(B.get(i), B_deltas.get(i), stepSize);
            final Matrix w = W.get(i);
            final Matrix dw = W_deltas.get(i);
            final Vec bias = B.get(i);
            final List<GradientUpdater> rowUpdaters = W_updaters.get(i);
            for (int r = 0; r < w.rows(); r++) {
                final int row = r;
                pending.add(executor.submit(new Runnable() {
                    @Override
                    public void run() {
                        Vec weightRow = w.getRowView(row);
                        rowUpdaters.get(row).update(weightRow, dw.getRowView(row), stepSize);
                        bias.set(row, regularizer.applyRegularizationToRow(weightRow, bias.get(row)));
                    }
                }));
            }
        }
        awaitAll(pending);
    }

    /**
     * Sets each entry of {@code matrix} to zero when a uniformly random {@code int} falls below
     * {@code threshold}. Because of how the threshold is built, this drops each entry with
     * roughly the configured probability.
     *
     * <p>When running in parallel, each stripe task gets its own {@link Random}, seeded from
     * {@code random} on the calling thread, so no generator is shared between threads. When
     * running serially, {@code random} is used directly.
     *
     * @param matrix matrix to drop entries from, modified in place
     * @param threshold value from {@link #dropoutThreshold(double)}
     * @param random source of randomness; used only by the calling thread
     * @param executorService pool used for parallel work, or {@code null}
     */
    private static void applyDropout(final Matrix matrix, final int threshold, final Random random,
            ExecutorService executorService) {
        final Random[] stripeRandoms;
        if (executorService == null) {
            stripeRandoms = new Random[] {random};
        } else {
            stripeRandoms = new Random[SystemInfo.LogicalCores];
            for (int s = 0; s < stripeRandoms.length; s++) {
                stripeRandoms[s] = new Random(random.nextLong());
            }
        }
        forEachRow(matrix.rows(), executorService, new StripedRowWork() {
            @Override
            public void processRow(int stripe, int row) {
                Random rand = stripeRandoms[stripe];
                for (int col = 0; col < matrix.cols(); col++) {
                    if (rand.nextInt() < threshold) {
                        matrix.set(row, col, 0.0d);
                    }
                }
            }
        });
    }

    /** Per-row callback used by {@link #forEachRow}. */
    private interface StripedRowWork {
        /**
         * @param stripe index of the task handling this row; always 0 when running serially
         * @param row the row to process
         */
        void processRow(int stripe, int row);
    }

    /**
     * Calls {@code work} for every row in {@code [0, rows)}.
     *
     * <p>With a {@code null} executor the rows are processed in order on the calling thread. With
     * an executor, {@code SystemInfo.LogicalCores} tasks are submitted, and task {@code s} handles
     * rows {@code s, s + cores, s + 2*cores, ...}. The method returns after every task has
     * finished, and rethrows the first task failure.
     */
    private static void forEachRow(final int rows, ExecutorService executor, final StripedRowWork work) {
        if (executor == null) {
            for (int row = 0; row < rows; row++) {
                work.processRow(0, row);
            }
            return;
        }
        final int stripes = SystemInfo.LogicalCores;
        List<Future<?>> pending = new ArrayList<Future<?>>(stripes);
        for (int s = 0; s < stripes; s++) {
            final int stripe = s;
            pending.add(executor.submit(new Runnable() {
                @Override
                public void run() {
                    for (int row = stripe; row < rows; row += stripes) {
                        work.processRow(stripe, row);
                    }
                }
            }));
        }
        awaitAll(pending);
    }

    /**
     * Waits for every future, then rethrows the first task failure unchecked.
     *
     * <p>If the calling thread is interrupted, it logs the interruption, restores the interrupt
     * flag and stops waiting. Tasks that are still running may therefore still be working.
     */
    private static void awaitAll(List<Future<?>> pending) {
        Throwable failure = null;
        for (Future<?> task : pending) {
            try {
                task.get();
            } catch (ExecutionException e) {
                if (failure == null) {
                    failure = e.getCause() != null ? e.getCause() : e;
                }
            } catch (InterruptedException e) {
                Logger.getLogger(SGDNetworkTrainer.class.getName())
                        .log(Level.SEVERE, "Interrupted while waiting for worker tasks", e);
                Thread.currentThread().interrupt();
                break;
            }
        }
        if (failure instanceof RuntimeException) {
            throw (RuntimeException) failure;
        }
        if (failure instanceof Error) {
            throw (Error) failure;
        }
        if (failure != null) {
            throw new RuntimeException(failure);
        }
    }

    /** Rejects dropout probabilities outside {@code [0, 1)} and NaN; returns {@code p} unchanged. */
    private static double checkDropoutProbability(double p) {
        if (p < 0.0d || p >= 1.0d || Double.isNaN(p)) {
            throw new IllegalArgumentException("Dropout probability must be in [0,1) not " + p);
        }
        return p;
    }

    /**
     * Maps a probability {@code p} to an {@code int} threshold {@code T}. A uniformly random
     * {@code int} (over {@code [-2^31, 2^31)}) is below {@code T} with probability close to
     * {@code p}. For {@code p == 0}, {@code T} is {@code Integer.MIN_VALUE}, so nothing is dropped.
     */
    private static int dropoutThreshold(double p) {
        return (int) ((4.294967295E9d * p) - 2.147483648E9d);
    }

    /**
     * Returns a deep copy of this trainer.
     *
     * @see #SGDNetworkTrainer(SGDNetworkTrainer)
     */
    public SGDNetworkTrainer m77clone() {
        return new SGDNetworkTrainer(this);
    }
}
