/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.ml.crf;

import com.johnsnowlabs.ml.crf.CrfDataset;
import com.johnsnowlabs.ml.crf.CrfParams;
import com.johnsnowlabs.ml.crf.DatasetMetadata;
import com.johnsnowlabs.ml.crf.FbCalculator;
import com.johnsnowlabs.ml.crf.Instance;
import com.johnsnowlabs.ml.crf.InstanceLabels;
import com.johnsnowlabs.ml.crf.L2DecayStrategy;
import com.johnsnowlabs.ml.crf.LinearChainCrf$;
import com.johnsnowlabs.ml.crf.LinearChainCrfModel;
import com.johnsnowlabs.ml.crf.VectorMath$;
import com.johnsnowlabs.nlp.annotators.ner.Verbose$;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.FloatRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.util.Random$;

@ScalaSignature(bytes="\u0006\u0001q4A!\u0001\u0002\u0001\u0017\tqA*\u001b8fCJ\u001c\u0005.Y5o\u0007J4'BA\u0002\u0005\u0003\r\u0019'O\u001a\u0006\u0003\u000b\u0019\t!!\u001c7\u000b\u0005\u001dA\u0011\u0001\u00046pQ:\u001chn\\<mC\n\u001c(\"A\u0005\u0002\u0007\r|Wn\u0001\u0001\u0014\u0005\u0001a\u0001CA\u0007\u0011\u001b\u0005q!\"A\b\u0002\u000bM\u001c\u0017\r\\1\n\u0005Eq!AB!osJ+g\r\u0003\u0005\u0014\u0001\t\u0015\r\u0011\"\u0001\u0015\u0003\u0019\u0001\u0018M]1ngV\tQ\u0003\u0005\u0002\u0017/5\t!!\u0003\u0002\u0019\u0005\tI1I\u001d4QCJ\fWn\u001d\u0005\t5\u0001\u0011\t\u0011)A\u0005+\u00059\u0001/\u0019:b[N\u0004\u0003\"\u0002\u000f\u0001\t\u0003i\u0012A\u0002\u001fj]&$h\b\u0006\u0002\u001f?A\u0011a\u0003\u0001\u0005\u0006'm\u0001\r!\u0006\u0005\bC\u0001\u0011\r\u0011\"\u0003#\u0003\u0019awnZ4feV\t1\u0005\u0005\u0002%S5\tQE\u0003\u0002'O\u0005)1\u000f\u001c45U*\t\u0001&A\u0002pe\u001eL!AK\u0013\u0003\r1{wmZ3s\u0011\u0019a\u0003\u0001)A\u0005G\u00059An\\4hKJ\u0004\u0003\"\u0002\u0018\u0001\t\u0003y\u0013a\u00017pOR\u0019\u0001gM \u0011\u00055\t\u0014B\u0001\u001a\u000f\u0005\u0011)f.\u001b;\t\rQjC\u00111\u00016\u0003\u00151\u0018\r\\;f!\ria\u0007O\u0005\u0003o9\u0011\u0001\u0002\u00102z]\u0006lWM\u0010\t\u0003sqr!!\u0004\u001e\n\u0005mr\u0011A\u0002)sK\u0012,g-\u0003\u0002>}\t11\u000b\u001e:j]\u001eT!a\u000f\b\t\u000b\u0001k\u0003\u0019A!\u0002\u00115Lg\u000eT3wK2\u0004\"A\u0011'\u000f\u0005\rSU\"\u0001#\u000b\u0005\u00153\u0015a\u00018fe*\u0011q\tS\u0001\u000bC:tw\u000e^1u_J\u001c(BA%\u0007\u0003\rqG\u000e]\u0005\u0003\u0017\u0012\u000bqAV3sE>\u001cX-\u0003\u0002N\u001d\n)A*\u001a<fY*\u00111\n\u0012\u0005\u0006!\u0002!\t!U\u0001\tiJ\f\u0017N\\*H\tR\u0011!+\u0016\t\u0003-MK!\u0001\u0016\u0002\u0003'1Kg.Z1s\u0007\"\f\u0017N\\\"sM6{G-\u001a7\t\u000bY{\u0005\u0019A,\u0002\u000f\u0011\fG/Y:fiB\u0011a\u0003W\u0005\u00033\n\u0011!b\u0011:g\t\u0006$\u0018m]3u\u0011\u0015Y\u0006\u0001\"\u0003]\u0003\u001d9W\r\u001e'pgN$B!\u00181fUB\u0011QBX\u0005\u0003?:\u0011QA\u00127pCRDQ!\u0019.A\u0002\t\f\u0001b]3oi\u0016t7-\u001a\t\u0003-\rL!\u0001\u001a\u0002\u0003\u0011%s7\u000f^1oG\u0016DQA\u001a.A\u0002\u001d\fa\u0001\\1cK2\u001c\bC\u0001\fi\u0013\tI'A\u0001\bJ]N$\u0018M\\2f\u0019\u0006\u0014W\r\\:\t\u000b-T\u0006\u0019\u00017\u0002\u000f\r|g\u000e^3yiB\u0011a#\\\u0005\u0003]\n\u0011AB\u00122DC2\u001cW\u000f\\1u_JDQ\u0001\u001d\u0001\u0005\u0002E\f\u0011\u0002Z8TO\u0012\u001cF/\u001a9\u0015\rA\u00128\u000f\u001e<|\u0011\u0015\tw\u000e1\u0001c\u0011\u00151w\u000e1\u0001h\u0011\u0015)x\u000e1\u0001^\u0003\u0005\t\u0007\"B<p\u0001\u0004A\u0018aB<fS\u001eDGo\u001d\t\u0004\u001bel\u0016B\u0001>\u000f\u0005\u0015\t%O]1z\u0011\u0015Yw\u000e1\u0001m\u0001")
public class LinearChainCrf {
    private final CrfParams params;
    private final Logger logger;

    public CrfParams params() {
        return this.params;
    }

    private Logger logger() {
        return this.logger;
    }

    public void log(Function0<String> value, Enumeration.Value minLevel) {
        if (minLevel.$greater$eq((Object)this.params().verbose())) {
            this.logger().info((String)value.apply());
        }
    }

    public LinearChainCrfModel trainSGD(CrfDataset dataset) {
        DatasetMetadata metadata = dataset.metadata();
        float[] weights = VectorMath$.MODULE$.Vector(dataset.metadata().attrFeatures().length + dataset.metadata().transitions().length, VectorMath$.MODULE$.Vector$default$2());
        int labels = dataset.metadata().labels().length;
        if (this.params().randomSeed().isDefined()) {
            Random$.MODULE$.setSeed((long)BoxesRunTime.unboxToInt((Object)this.params().randomSeed().get()));
        }
        int maxLength = BoxesRunTime.unboxToInt((Object)((TraversableOnce)dataset.instances().map((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final int apply(Tuple2<InstanceLabels, Instance> w) {
                return ((Instance)w._2()).items().size();
            }
        }, Seq$.MODULE$.canBuildFrom())).max((Ordering)Ordering.Int$.MODULE$));
        this.log((Function0<String>)new Serializable(this, labels){
            public static final long serialVersionUID = 0L;
            private final int labels$1;

            public final String apply() {
                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"labels: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.labels$1)}));
            }
            {
                this.labels$1 = labels$1;
            }
        }, Verbose$.MODULE$.TrainingStat());
        this.log((Function0<String>)new Serializable(this, dataset){
            public static final long serialVersionUID = 0L;
            private final CrfDataset dataset$1;

            public final String apply() {
                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"instances: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.dataset$1.instances().size())}));
            }
            {
                this.dataset$1 = dataset$1;
            }
        }, Verbose$.MODULE$.TrainingStat());
        this.log((Function0<String>)new Serializable(this, weights){
            public static final long serialVersionUID = 0L;
            private final float[] weights$1;

            public final String apply() {
                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"features: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.weights$1.length)}));
            }
            {
                this.weights$1 = weights$1;
            }
        }, Verbose$.MODULE$.TrainingStat());
        this.log((Function0<String>)new Serializable(this, maxLength){
            public static final long serialVersionUID = 0L;
            private final int maxLength$1;

            public final String apply() {
                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"maxLength: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.maxLength$1)}));
            }
            {
                this.maxLength$1 = maxLength$1;
            }
        }, Verbose$.MODULE$.TrainingStat());
        FbCalculator context = new FbCalculator(maxLength, metadata);
        float[] bestW = VectorMath$.MODULE$.Vector(weights.length, VectorMath$.MODULE$.Vector$default$2());
        FloatRef bestLoss = FloatRef.create((float)Float.MAX_VALUE);
        FloatRef lastLoss = FloatRef.create((float)Float.MAX_VALUE);
        IntRef notImprovedEpochs = IntRef.create((int)0);
        L2DecayStrategy decayStrategy = new L2DecayStrategy(dataset.instances().size(), this.params().l2(), this.params().c0());
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this.params().maxEpochs()).withFilter((Function1)new Serializable(this, notImprovedEpochs){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LinearChainCrf $outer;
            private final IntRef notImprovedEpochs$1;

            public final boolean apply(int epoch) {
                return this.apply$mcZI$sp(epoch);
            }

            public boolean apply$mcZI$sp(int epoch) {
                return this.notImprovedEpochs$1.elem < 10 || epoch < this.$outer.params().minEpochs();
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.notImprovedEpochs$1 = notImprovedEpochs$1;
            }
        }).foreach((Function1)new Serializable(this, dataset, weights, context, bestW, bestLoss, lastLoss, notImprovedEpochs, decayStrategy){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LinearChainCrf $outer;
            private final CrfDataset dataset$1;
            public final float[] weights$1;
            public final FbCalculator context$1;
            private final float[] bestW$1;
            private final FloatRef bestLoss$1;
            private final FloatRef lastLoss$1;
            private final IntRef notImprovedEpochs$1;
            public final L2DecayStrategy decayStrategy$1;

            public final void apply(int epoch) {
                this.apply$mcVI$sp(epoch);
            }

            public void apply$mcVI$sp(int epoch) {
                FloatRef loss = FloatRef.create((float)0.0f);
                this.$outer.log((Function0<String>)new Serializable(this, epoch){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ $anonfun$trainSGD$2 $outer;
                    private final int epoch$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"\\nEpoch: ", ", eta: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.epoch$1), BoxesRunTime.boxToFloat((float)this.$outer.decayStrategy$1.eta())}));
                    }
                    {
                        if ($outer == null) {
                            throw null;
                        }
                        this.$outer = $outer;
                        this.epoch$1 = epoch$1;
                    }
                }, Verbose$.MODULE$.Epochs());
                long started = System.nanoTime();
                Seq shuffled = (Seq)Random$.MODULE$.shuffle(this.dataset$1.instances(), Seq$.MODULE$.canBuildFrom());
                IntRef instancesCount = IntRef.create((int)0);
                shuffled.withFilter((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final boolean apply(Tuple2<InstanceLabels, Instance> check$ifrefutable$1) {
                        Tuple2<InstanceLabels, Instance> tuple2 = check$ifrefutable$1;
                        boolean bl = tuple2 != null;
                        return bl;
                    }
                }).foreach((Function1)new Serializable(this, loss, instancesCount){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ $anonfun$trainSGD$2 $outer;
                    private final FloatRef loss$1;
                    private final IntRef instancesCount$1;

                    public final void apply(Tuple2<InstanceLabels, Instance> x$1) {
                        Tuple2<InstanceLabels, Instance> tuple2 = x$1;
                        if (tuple2 != null) {
                            BoxedUnit boxedUnit;
                            InstanceLabels labels = (InstanceLabels)tuple2._1();
                            Instance sentence = (Instance)tuple2._2();
                            this.$outer.decayStrategy$1.nextStep();
                            this.$outer.context$1.calculate(sentence, this.$outer.weights$1, this.$outer.decayStrategy$1.getScale());
                            this.$outer.com$johnsnowlabs$ml$crf$LinearChainCrf$$anonfun$$$outer().doSgdStep(sentence, labels, this.$outer.decayStrategy$1.alpha(), this.$outer.weights$1, this.$outer.context$1);
                            this.loss$1.elem += this.$outer.com$johnsnowlabs$ml$crf$LinearChainCrf$$anonfun$$$outer().com$johnsnowlabs$ml$crf$LinearChainCrf$$getLoss(sentence, labels, this.$outer.context$1);
                            ++this.instancesCount$1.elem;
                            if (this.instancesCount$1.elem % 1000 == 0) {
                                this.$outer.decayStrategy$1.reset(this.$outer.weights$1);
                                boxedUnit = BoxedUnit.UNIT;
                            } else {
                                boxedUnit = BoxedUnit.UNIT;
                            }
                            BoxedUnit boxedUnit2 = boxedUnit;
                            return;
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        if ($outer == null) {
                            throw null;
                        }
                        this.$outer = $outer;
                        this.loss$1 = loss$1;
                        this.instancesCount$1 = instancesCount$1;
                    }
                });
                this.decayStrategy$1.reset(this.weights$1);
                float l2Loss = this.$outer.params().l2() * BoxesRunTime.unboxToFloat((Object)Predef$.MODULE$.floatArrayOps((float[])Predef$.MODULE$.floatArrayOps(this.weights$1).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final float apply(float w) {
                        return this.apply$mcFF$sp(w);
                    }

                    public float apply$mcFF$sp(float w) {
                        return w * w;
                    }
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Float()))).sum((Numeric)Numeric.FloatIsFractional$.MODULE$));
                float totalLoss = loss.elem + l2Loss;
                this.$outer.log((Function0<String>)new Serializable(this, started){
                    public static final long serialVersionUID = 0L;
                    private final long started$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"finished, time: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)((double)(System.nanoTime() - this.started$1) / 1.0E9))}));
                    }
                    {
                        this.started$1 = started$1;
                    }
                }, Verbose$.MODULE$.Epochs());
                this.$outer.log((Function0<String>)new Serializable(this, loss, l2Loss, totalLoss){
                    public static final long serialVersionUID = 0L;
                    private final FloatRef loss$1;
                    private final float l2Loss$1;
                    private final float totalLoss$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Loss = ", ", logLoss = ", ", l2Loss = ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToFloat((float)this.totalLoss$1), BoxesRunTime.boxToFloat((float)this.loss$1.elem), BoxesRunTime.boxToFloat((float)this.l2Loss$1)}));
                    }
                    {
                        this.loss$1 = loss$1;
                        this.l2Loss$1 = l2Loss$1;
                        this.totalLoss$1 = totalLoss$1;
                    }
                }, Verbose$.MODULE$.Epochs());
                if (totalLoss < this.bestLoss$1.elem) {
                    this.bestLoss$1.elem = totalLoss;
                    VectorMath$.MODULE$.copy(this.weights$1, this.bestW$1);
                    this.notImprovedEpochs$1.elem = (this.bestLoss$1.elem - totalLoss) / totalLoss < this.$outer.params().lossEps() ? 0 : ++this.notImprovedEpochs$1.elem;
                } else {
                    ++this.notImprovedEpochs$1.elem;
                }
                this.lastLoss$1.elem = totalLoss;
            }

            public /* synthetic */ LinearChainCrf com$johnsnowlabs$ml$crf$LinearChainCrf$$anonfun$$$outer() {
                return this.$outer;
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.dataset$1 = dataset$1;
                this.weights$1 = weights$1;
                this.context$1 = context$1;
                this.bestW$1 = bestW$1;
                this.bestLoss$1 = bestLoss$1;
                this.lastLoss$1 = lastLoss$1;
                this.notImprovedEpochs$1 = notImprovedEpochs$1;
                this.decayStrategy$1 = decayStrategy$1;
            }
        });
        return new LinearChainCrfModel(bestW, metadata);
    }

    public float com$johnsnowlabs$ml$crf$LinearChainCrf$$getLoss(Instance sentence, InstanceLabels labels, FbCalculator context) {
        int length = sentence.items().length();
        IntRef prevLabel = IntRef.create((int)0);
        FloatRef result = FloatRef.create((float)0.0f);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), length).foreach$mVc$sp((Function1)new Serializable(this, labels, context, prevLabel, result){
            public static final long serialVersionUID = 0L;
            private final InstanceLabels labels$2;
            private final FbCalculator context$2;
            private final IntRef prevLabel$1;
            private final FloatRef result$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                this.result$1.elem -= this.context$2.logPhi()[i][this.prevLabel$1.elem][BoxesRunTime.unboxToInt((Object)this.labels$2.labels().apply(i))];
                this.prevLabel$1.elem = BoxesRunTime.unboxToInt((Object)this.labels$2.labels().apply(i));
                this.result$1.elem += (float)Math.log(this.context$2.c()[i]);
            }
            {
                this.labels$2 = labels$2;
                this.context$2 = context$2;
                this.prevLabel$1 = prevLabel$1;
                this.result$1 = result$1;
            }
        });
        if (result.elem >= 0.0f) {
            Predef$.MODULE$.assert(result.elem >= 0.0f);
        }
        return result.elem;
    }

    public void doSgdStep(Instance sentence, InstanceLabels labels, float a, float[] weights, FbCalculator context) {
        context.addObservedExpectations(weights, sentence, labels, a);
        context.addModelExpectations(weights, sentence, -a);
    }

    public LinearChainCrf(CrfParams params) {
        this.params = params;
        this.logger = LoggerFactory.getLogger((String)"CRF");
    }
}

