/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import org.apache.spark.ml.ann.ANNGradient;
import org.apache.spark.ml.ann.ANNUpdater;
import org.apache.spark.ml.ann.DataStacker;
import org.apache.spark.ml.ann.Topology;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.GradientDescent;
import org.apache.spark.mllib.optimization.LBFGS;
import org.apache.spark.mllib.optimization.Optimizer;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.rdd.RDD;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;

@ScalaSignature(bytes="\u0006\u0001\u0005Ud!B\u0001\u0003\u0001\u0011a!A\u0005$fK\u00124uN]<be\u0012$&/Y5oKJT!a\u0001\u0003\u0002\u0007\u0005tgN\u0003\u0002\u0006\r\u0005\u0011Q\u000e\u001c\u0006\u0003\u000f!\tQa\u001d9be.T!!\u0003\u0006\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005Y\u0011aA8sON\u0019\u0001!D\n\u0011\u00059\tR\"A\b\u000b\u0003A\tQa]2bY\u0006L!AE\b\u0003\r\u0005s\u0017PU3g!\tqA#\u0003\u0002\u0016\u001f\ta1+\u001a:jC2L'0\u00192mK\"Aq\u0003\u0001B\u0001B\u0003%\u0011$\u0001\u0005u_B|Gn\\4z\u0007\u0001\u0001\"AG\u000e\u000e\u0003\tI!\u0001\b\u0002\u0003\u0011Q{\u0007o\u001c7pOfD\u0001B\b\u0001\u0003\u0006\u0004%\taH\u0001\nS:\u0004X\u000f^*ju\u0016,\u0012\u0001\t\t\u0003\u001d\u0005J!AI\b\u0003\u0007%sG\u000f\u0003\u0005%\u0001\t\u0005\t\u0015!\u0003!\u0003)Ig\u000e];u'&TX\r\t\u0005\tM\u0001\u0011)\u0019!C\u0001?\u0005Qq.\u001e;qkR\u001c\u0016N_3\t\u0011!\u0002!\u0011!Q\u0001\n\u0001\n1b\\;uaV$8+\u001b>fA!)!\u0006\u0001C\u0001W\u00051A(\u001b8jiz\"B\u0001L\u0017/_A\u0011!\u0004\u0001\u0005\u0006/%\u0002\r!\u0007\u0005\u0006=%\u0002\r\u0001\t\u0005\u0006M%\u0002\r\u0001\t\u0005\bc\u0001\u0001\r\u0011\"\u00033\u0003!yv/Z5hQR\u001cX#A\u001a\u0011\u0005QJT\"A\u001b\u000b\u0005Y:\u0014A\u00027j]\u0006dwM\u0003\u00029\r\u0005)Q\u000e\u001c7jE&\u0011!(\u000e\u0002\u0007-\u0016\u001cGo\u001c:\t\u000fq\u0002\u0001\u0019!C\u0005{\u0005aql^3jO\"$8o\u0018\u0013fcR\u0011a(\u0011\t\u0003\u001d}J!\u0001Q\b\u0003\tUs\u0017\u000e\u001e\u0005\b\u0005n\n\t\u00111\u00014\u0003\rAH%\r\u0005\u0007\t\u0002\u0001\u000b\u0015B\u001a\u0002\u0013};X-[4iiN\u0004\u0003b\u0002$\u0001\u0001\u0004%IaH\u0001\u000b?N$\u0018mY6TSj,\u0007b\u0002%\u0001\u0001\u0004%I!S\u0001\u000f?N$\u0018mY6TSj,w\fJ3r)\tq$\nC\u0004C\u000f\u0006\u0005\t\u0019\u0001\u0011\t\r1\u0003\u0001\u0015)\u0003!\u0003-y6\u000f^1dWNK'0\u001a\u0011\t\u000f9\u0003\u0001\u0019!C\u0005\u001f\u0006YA-\u0019;b'R\f7m[3s+\u0005\u0001\u0006C\u0001\u000eR\u0013\t\u0011&AA\u0006ECR\f7\u000b^1dW\u0016\u0014\bb\u0002+\u0001\u0001\u0004%I!V\u0001\u0010I\u0006$\u0018m\u0015;bG.,'o\u0018\u0013fcR\u0011aH\u0016\u0005\b\u0005N\u000b\t\u00111\u0001Q\u0011\u0019A\u0006\u0001)Q\u0005!\u0006aA-\u0019;b'R\f7m[3sA!9!\f\u0001a\u0001\n\u0013Y\u0016!C0he\u0006$\u0017.\u001a8u+\u0005a\u0006CA/a\u001b\u0005q&BA08\u00031y\u0007\u000f^5nSj\fG/[8o\u0013\t\tgL\u0001\u0005He\u0006$\u0017.\u001a8u\u0011\u001d\u0019\u0007\u00011A\u0005\n\u0011\fQbX4sC\u0012LWM\u001c;`I\u0015\fHC\u0001 f\u0011\u001d\u0011%-!AA\u0002qCaa\u001a\u0001!B\u0013a\u0016AC0he\u0006$\u0017.\u001a8uA!9\u0011\u000e\u0001a\u0001\n\u0013Q\u0017\u0001C0va\u0012\fG/\u001a:\u0016\u0003-\u0004\"!\u00187\n\u00055t&aB+qI\u0006$XM\u001d\u0005\b_\u0002\u0001\r\u0011\"\u0003q\u00031yV\u000f\u001d3bi\u0016\u0014x\fJ3r)\tq\u0014\u000fC\u0004C]\u0006\u0005\t\u0019A6\t\rM\u0004\u0001\u0015)\u0003l\u0003%yV\u000f\u001d3bi\u0016\u0014\b\u0005C\u0004v\u0001\u0001\u0007I\u0011\u0002<\u0002\u0013=\u0004H/[7ju\u0016\u0014X#A<\u0011\u0005uC\u0018BA=_\u0005%y\u0005\u000f^5nSj,'\u000fC\u0004|\u0001\u0001\u0007I\u0011\u0002?\u0002\u001b=\u0004H/[7ju\u0016\u0014x\fJ3r)\tqT\u0010C\u0004Cu\u0006\u0005\t\u0019A<\t\r}\u0004\u0001\u0015)\u0003x\u0003)y\u0007\u000f^5nSj,'\u000f\t\u0005\u0007\u0003\u0007\u0001A\u0011\u0001\u001a\u0002\u0015\u001d,GoV3jO\"$8\u000fC\u0004\u0002\b\u0001!\t!!\u0003\u0002\u0015M,GoV3jO\"$8\u000fF\u0002-\u0003\u0017Aq!!\u0004\u0002\u0006\u0001\u00071'A\u0003wC2,X\rC\u0004\u0002\u0012\u0001!\t!a\u0005\u0002\u0019M,Go\u0015;bG.\u001c\u0016N_3\u0015\u00071\n)\u0002C\u0004\u0002\u000e\u0005=\u0001\u0019\u0001\u0011\t\u000f\u0005e\u0001\u0001\"\u0001\u0002\u001c\u0005a1k\u0012#PaRLW.\u001b>feV\u0011\u0011Q\u0004\t\u0004;\u0006}\u0011bAA\u0011=\nyqI]1eS\u0016tG\u000fR3tG\u0016tG\u000fC\u0004\u0002&\u0001!\t!a\n\u0002\u001d1\u0013eiR*PaRLW.\u001b>feV\u0011\u0011\u0011\u0006\t\u0004;\u0006-\u0012bAA\u0017=\n)AJ\u0011$H'\"9\u0011\u0011\u0007\u0001\u0005\u0002\u0005M\u0012AC:fiV\u0003H-\u0019;feR\u0019A&!\u000e\t\u000f\u00055\u0011q\u0006a\u0001W\"9\u0011\u0011\b\u0001\u0005\u0002\u0005m\u0012aC:fi\u001e\u0013\u0018\rZ5f]R$2\u0001LA\u001f\u0011\u001d\ti!a\u000eA\u0002qC\u0001\"!\u0011\u0001A\u0013%\u00111I\u0001\u000fkB$\u0017\r^3He\u0006$\u0017.\u001a8u)\rq\u0014Q\t\u0005\b\u0003\u000f\ny\u00041\u0001]\u0003!9'/\u00193jK:$\b\u0002CA&\u0001\u0001&I!!\u0014\u0002\u001bU\u0004H-\u0019;f+B$\u0017\r^3s)\rq\u0014q\n\u0005\b\u0003#\nI\u00051\u0001l\u0003\u001d)\b\u000fZ1uKJDq!!\u0016\u0001\t\u0003\t9&A\u0003ue\u0006Lg\u000e\u0006\u0003\u0002Z\u0005}\u0003c\u0001\u000e\u0002\\%\u0019\u0011Q\f\u0002\u0003\u001bQ{\u0007o\u001c7pOflu\u000eZ3m\u0011!\t\t'a\u0015A\u0002\u0005\r\u0014\u0001\u00023bi\u0006\u0004b!!\u001a\u0002l\u0005=TBAA4\u0015\r\tIGB\u0001\u0004e\u0012$\u0017\u0002BA7\u0003O\u00121A\u0015#E!\u0015q\u0011\u0011O\u001a4\u0013\r\t\u0019h\u0004\u0002\u0007)V\u0004H.\u001a\u001a")
public class FeedForwardTrainer
implements Serializable {
    private final Topology topology;
    private final int inputSize;
    private final int outputSize;
    private Vector _weights;
    private int _stackSize;
    private DataStacker dataStacker;
    private Gradient _gradient;
    private Updater _updater;
    private Optimizer optimizer;

    public int inputSize() {
        return this.inputSize;
    }

    public int outputSize() {
        return this.outputSize;
    }

    private Vector _weights() {
        return this._weights;
    }

    private void _weights_$eq(Vector x$1) {
        this._weights = x$1;
    }

    private int _stackSize() {
        return this._stackSize;
    }

    private void _stackSize_$eq(int x$1) {
        this._stackSize = x$1;
    }

    private DataStacker dataStacker() {
        return this.dataStacker;
    }

    private void dataStacker_$eq(DataStacker x$1) {
        this.dataStacker = x$1;
    }

    private Gradient _gradient() {
        return this._gradient;
    }

    private void _gradient_$eq(Gradient x$1) {
        this._gradient = x$1;
    }

    private Updater _updater() {
        return this._updater;
    }

    private void _updater_$eq(Updater x$1) {
        this._updater = x$1;
    }

    private Optimizer optimizer() {
        return this.optimizer;
    }

    private void optimizer_$eq(Optimizer x$1) {
        this.optimizer = x$1;
    }

    public Vector getWeights() {
        return this._weights();
    }

    public FeedForwardTrainer setWeights(Vector value) {
        this._weights_$eq(value);
        return this;
    }

    public FeedForwardTrainer setStackSize(int value) {
        this._stackSize_$eq(value);
        this.dataStacker_$eq(new DataStacker(value, this.inputSize(), this.outputSize()));
        return this;
    }

    /*
     * WARNING - void declaration
     */
    public GradientDescent SGDOptimizer() {
        void var1_1;
        GradientDescent sgd = new GradientDescent(this._gradient(), this._updater());
        this.optimizer_$eq(sgd);
        return var1_1;
    }

    /*
     * WARNING - void declaration
     */
    public LBFGS LBFGSOptimizer() {
        void var1_1;
        LBFGS lbfgs = new LBFGS(this._gradient(), this._updater());
        this.optimizer_$eq(lbfgs);
        return var1_1;
    }

    public FeedForwardTrainer setUpdater(Updater value) {
        this._updater_$eq(value);
        this.updateUpdater(value);
        return this;
    }

    public FeedForwardTrainer setGradient(Gradient value) {
        this._gradient_$eq(value);
        this.updateGradient(value);
        return this;
    }

    private void updateGradient(Gradient gradient2) {
        Optimizer optimizer;
        block4: {
            block3: {
                block2: {
                    optimizer = this.optimizer();
                    if (!(optimizer instanceof LBFGS)) break block2;
                    LBFGS lBFGS = (LBFGS)optimizer;
                    lBFGS.setGradient(gradient2);
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                    break block3;
                }
                if (!(optimizer instanceof GradientDescent)) break block4;
                GradientDescent gradientDescent = (GradientDescent)optimizer;
                gradientDescent.setGradient(gradient2);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            }
            return;
        }
        throw new UnsupportedOperationException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Only LBFGS and GradientDescent are supported but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{optimizer.getClass()})));
    }

    private void updateUpdater(Updater updater) {
        Optimizer optimizer;
        block4: {
            block3: {
                block2: {
                    optimizer = this.optimizer();
                    if (!(optimizer instanceof LBFGS)) break block2;
                    LBFGS lBFGS = (LBFGS)optimizer;
                    lBFGS.setUpdater(updater);
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                    break block3;
                }
                if (!(optimizer instanceof GradientDescent)) break block4;
                GradientDescent gradientDescent = (GradientDescent)optimizer;
                gradientDescent.setUpdater(updater);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            }
            return;
        }
        throw new UnsupportedOperationException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Only LBFGS and GradientDescent are supported but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{optimizer.getClass()})));
    }

    public TopologyModel train(RDD<Tuple2<Vector, Vector>> data) {
        Vector newWeights = this.optimizer().optimize(this.dataStacker().stack(data), this.getWeights());
        return this.topology.getInstance(newWeights);
    }

    public FeedForwardTrainer(Topology topology, int inputSize, int outputSize) {
        this.topology = topology;
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        this._weights = topology.getInstance(11L).weights();
        this._stackSize = 128;
        this.dataStacker = new DataStacker(this._stackSize(), inputSize, outputSize);
        this._gradient = new ANNGradient(topology, this.dataStacker());
        this._updater = new ANNUpdater();
        this.optimizer = this.LBFGSOptimizer().setConvergenceTol(1.0E-4).setNumIterations(100);
    }
}

