package cc.factorie.app.nlp.parse;

import cc.factorie.app.classify.backend.LinearMulticlassClassifier;
import cc.factorie.app.classify.backend.MulticlassClassifierTrainer;
import cc.factorie.app.classify.backend.OnlineLinearMulticlassTrainer;
import cc.factorie.app.classify.backend.OnlineLinearMulticlassTrainer$;
import cc.factorie.app.classify.backend.SVMMulticlassTrainer;
import cc.factorie.app.classify.backend.SVMMulticlassTrainer$;
import cc.factorie.app.nlp.Sentence;
import cc.factorie.la.Tensor2;
import cc.factorie.optimize.AdaGradRDA;
import cc.factorie.optimize.AdaGradRDA$;
import cc.factorie.optimize.OptimizableObjectives$;
import cc.factorie.package$;
import cc.factorie.util.BoxedDouble;
import cc.factorie.util.CmdOptions;
import cc.factorie.util.FileUtils$;
import cc.factorie.util.HyperparameterMain;
import cc.factorie.variable.CategoricalDomain;
import java.io.File;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Iterable;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;
import scala.runtime.RichInt$;
import scala.util.Random;

/* compiled from: TransitionBasedParser.scala */
/* loaded from: input_file:cc/factorie/app/nlp/parse/TransitionBasedParserTrainer$.class */
/**
 * Singleton module (compiled form of a Scala {@code object}) that trains, evaluates and
 * optionally serializes a {@link TransitionBasedParser}.
 *
 * <p>{@link #main(String[])} and {@link #actualMain(String[])} delegate to the
 * {@code HyperparameterMain} trait implementation. The actual work is done in
 * {@link #evaluateParameters(String[])}, which returns the labeled attachment score (LAS)
 * on the test sentences.
 *
 * <p>NOTE(review): member names containing {@code $} are compiler-generated and may be
 * referenced by the separately compiled {@code $anonfun} classes. Do not rename them.
 */
public final class TransitionBasedParserTrainer$ implements HyperparameterMain {
    /** The singleton instance; assigned by the private constructor run from the static initializer. */
    public static final TransitionBasedParserTrainer$ MODULE$ = null;

    static {
        // Constructing the instance has the side effect of publishing it via MODULE$.
        new TransitionBasedParserTrainer$();
    }

    /** Entry point; delegates to the {@code HyperparameterMain} trait implementation. */
    @Override // cc.factorie.util.HyperparameterMain
    public final void main(String[] strArr) {
        HyperparameterMain.Cclass.main(this, strArr);
    }

    /** Delegates to the {@code HyperparameterMain} trait implementation. */
    @Override // cc.factorie.util.HyperparameterMain
    public final BoxedDouble actualMain(String[] strArr) {
        return HyperparameterMain.Cclass.actualMain(this, strArr);
    }

    /**
     * Trains a transition-based parser from command-line arguments and scores it on the
     * test sentences.
     *
     * <p>Pipeline:
     * <ol>
     *   <li>Parse {@code strArr} into {@link TransitionBasedParserArgs}. At least one of
     *       the train-files or train-dir options must be given.</li>
     *   <li>Load train/dev/test sentences and keep the leading fraction of each.
     *       {@code trainPortion} applies to train; {@code testPortion} applies to both test
     *       and dev. Both default to 1.0.</li>
     *   <li>Build either an {@code SVMMulticlassTrainer} or an
     *       {@code OnlineLinearMulticlassTrainer} (AdaGradRDA, hinge multiclass objective).</li>
     *   <li>Run one decision-generation pass with feature counting enabled. Trim features
     *       below {@code 2 * cutoff}, then freeze the feature domain.</li>
     *   <li>Train on a second decision-generation pass, then run {@code bootstrapping}
     *       extra rounds.</li>
     *   <li>If {@code saveModel}: serialize, deserialize into a fresh parser and print its
     *       test accuracy.</li>
     *   <li>Compute LAS on the test sentences. If a target accuracy was given, assert that
     *       it is met.</li>
     * </ol>
     *
     * @param strArr raw command-line arguments
     * @return LAS on the (portioned) test sentences
     */
    @Override // cc.factorie.util.HyperparameterMain
    public double evaluateParameters(String[] strArr) {
        MulticlassClassifierTrainer<LinearMulticlassClassifier> onlineLinearMulticlassTrainer;
        TransitionBasedParserArgs transitionBasedParserArgs = new TransitionBasedParserArgs();
        // Fixed seed so trainer randomness is reproducible across runs.
        Random random = new Random(0);
        transitionBasedParserArgs.parse(Predef$.MODULE$.wrapRefArray(strArr));
        Predef$.MODULE$.assert(transitionBasedParserArgs.trainFiles().wasInvoked() || transitionBasedParserArgs.trainDir().wasInvoked());
        Seq loadSentences$1 = loadSentences$1(transitionBasedParserArgs.trainFiles(), transitionBasedParserArgs.trainDir(), transitionBasedParserArgs);
        Seq loadSentences$12 = loadSentences$1(transitionBasedParserArgs.devFiles(), transitionBasedParserArgs.devDir(), transitionBasedParserArgs);
        Seq loadSentences$13 = loadSentences$1(transitionBasedParserArgs.testFiles(), transitionBasedParserArgs.testDir(), transitionBasedParserArgs);
        double unboxToDouble = transitionBasedParserArgs.trainPortion().wasInvoked() ? BoxesRunTime.unboxToDouble(transitionBasedParserArgs.trainPortion().value()) : 1.0d;
        double unboxToDouble2 = transitionBasedParserArgs.testPortion().wasInvoked() ? BoxesRunTime.unboxToDouble(transitionBasedParserArgs.testPortion().value()) : 1.0d;
        // Keep floor(portion * size) leading sentences; the dev set reuses testPortion.
        Iterable<Sentence> iterable = (Seq) loadSentences$1.take((int) RichDouble$.MODULE$.floor$extension(Predef$.MODULE$.doubleWrapper(unboxToDouble * loadSentences$1.length())));
        Seq seq = (Seq) loadSentences$13.take((int) RichDouble$.MODULE$.floor$extension(Predef$.MODULE$.doubleWrapper(unboxToDouble2 * loadSentences$13.length())));
        Seq seq2 = (Seq) loadSentences$12.take((int) RichDouble$.MODULE$.floor$extension(Predef$.MODULE$.doubleWrapper(unboxToDouble2 * loadSentences$12.length())));
        Predef$.MODULE$.println(new StringBuilder().append("Total train sentences: ").append(BoxesRunTime.boxToInteger(iterable.size())).toString());
        Predef$.MODULE$.println(new StringBuilder().append("Total test sentences: ").append(BoxesRunTime.boxToInteger(seq.size())).toString());
        // Number of extra bootstrapping rounds run after the initial training pass.
        int unboxToInt = BoxesRunTime.unboxToInt(transitionBasedParserArgs.bootstrapping().value());
        TransitionBasedParser transitionBasedParser = new TransitionBasedParser();
        // L1/L2 strengths are scaled by 2 / (number of training sentences).
        AdaGradRDA adaGradRDA = new AdaGradRDA(BoxesRunTime.unboxToDouble(transitionBasedParserArgs.rate().value()), BoxesRunTime.unboxToDouble(transitionBasedParserArgs.delta().value()), (2 * BoxesRunTime.unboxToDouble(transitionBasedParserArgs.l1().value())) / iterable.length(), (2 * BoxesRunTime.unboxToDouble(transitionBasedParserArgs.l2().value())) / iterable.length(), AdaGradRDA$.MODULE$.$lessinit$greater$default$5());
        Predef$.MODULE$.println(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Initializing trainer (", " threads)"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{transitionBasedParserArgs.nThreads().value()})));
        if (BoxesRunTime.unboxToBoolean(transitionBasedParserArgs.useSVM().value())) {
            onlineLinearMulticlassTrainer = new SVMMulticlassTrainer(BoxesRunTime.unboxToInt(transitionBasedParserArgs.nThreads().value()), SVMMulticlassTrainer$.MODULE$.$lessinit$greater$default$2(), random);
        } else {
            // First argument: use the parallel variant only when more than one thread is requested.
            onlineLinearMulticlassTrainer = new OnlineLinearMulticlassTrainer(BoxesRunTime.unboxToInt(transitionBasedParserArgs.nThreads().value()) > 1, adaGradRDA, OptimizableObjectives$.MODULE$.hingeMulticlass(), BoxesRunTime.unboxToInt(transitionBasedParserArgs.maxIters().value()), OnlineLinearMulticlassTrainer$.MODULE$.$lessinit$greater$default$5(), BoxesRunTime.unboxToInt(transitionBasedParserArgs.nThreads().value()), random);
        }
        MulticlassClassifierTrainer<LinearMulticlassClassifier> multiclassClassifierTrainer = onlineLinearMulticlassTrainer;
        // Pass 1: generate decisions only to collect feature counts for pruning.
        transitionBasedParser.featuresDomain().dimensionDomain().gatherCounts_$eq(true);
        Predef$.MODULE$.println("Generating decisions...");
        transitionBasedParser.generateDecisions(iterable, transitionBasedParser.ParserConstants().TRAINING(), BoxesRunTime.unboxToInt(transitionBasedParserArgs.nThreads().value()));
        Predef$.MODULE$.println(new StringBuilder().append("Before pruning # features ").append(BoxesRunTime.boxToInteger(transitionBasedParser.featuresDomain().dimensionDomain().size())).toString());
        CategoricalDomain<String> dimensionDomain = transitionBasedParser.featuresDomain().dimensionDomain();
        dimensionDomain.trimBelowCount(2 * BoxesRunTime.unboxToInt(transitionBasedParserArgs.cutoff().value()), dimensionDomain.trimBelowCount$default$2());
        // Freeze so the training pass below cannot add new features.
        transitionBasedParser.featuresDomain().freeze();
        transitionBasedParser.featuresDomain().dimensionDomain().gatherCounts_$eq(false);
        Predef$.MODULE$.println(new StringBuilder().append("After pruning # features ").append(BoxesRunTime.boxToInteger(transitionBasedParser.featuresDomain().dimensionDomain().size())).toString());
        Predef$.MODULE$.println("Training...");
        // Pass 2: regenerate decisions against the pruned, frozen domain and train on them.
        transitionBasedParser.trainFromVariables(transitionBasedParser.generateDecisions(iterable, transitionBasedParser.ParserConstants().TRAINING(), BoxesRunTime.unboxToInt(transitionBasedParserArgs.nThreads().value())), multiclassClassifierTrainer, new TransitionBasedParserTrainer$$anonfun$evaluateParameters$2(iterable, seq, seq2, transitionBasedParser));
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), unboxToInt).foreach$mVc$sp(new TransitionBasedParserTrainer$$anonfun$evaluateParameters$1(transitionBasedParserArgs, iterable, seq, seq2, transitionBasedParser, multiclassClassifierTrainer));
        if (BoxesRunTime.unboxToBoolean(transitionBasedParserArgs.saveModel().value())) {
            // Use modelDir if given; otherwise its default value + current millis + ".factorie".
            String value = transitionBasedParserArgs.modelDir().wasInvoked() ? transitionBasedParserArgs.modelDir().value() : new StringBuilder().append(transitionBasedParserArgs.modelDir().defaultValue()).append(BoxesRunTime.boxToLong(System.currentTimeMillis()).toString()).append(".factorie").toString();
            transitionBasedParser.serialize(new File(value));
            // Round-trip check: reload into a fresh parser and report its test accuracy.
            TransitionBasedParser transitionBasedParser2 = new TransitionBasedParser();
            transitionBasedParser2.deserialize(new File(value));
            testSingle$1(transitionBasedParser2, seq, "Post serialization accuracy ");
        }
        double calcLas = ParserEval$.MODULE$.calcLas((Iterable) seq.map(new TransitionBasedParserTrainer$$anonfun$32(), Seq$.MODULE$.canBuildFrom()), ParserEval$.MODULE$.calcLas$default$2());
        if (transitionBasedParserArgs.targetAccuracy().wasInvoked()) {
            package$.MODULE$.assertMinimalAccuracy(calcLas, new StringOps(Predef$.MODULE$.augmentString(transitionBasedParserArgs.targetAccuracy().value())).toDouble());
        }
        return calcLas;
    }

    /**
     * Collects input files from an explicit file-list option and/or a directory option, and
     * loads sentences from each of them.
     *
     * @param cmdOption option holding an explicit list of file names (used only if invoked)
     * @param cmdOption2 option holding a directory whose file list is appended (used only if invoked)
     * @param transitionBasedParserArgs parsed arguments, passed through to the per-file
     *        loader (presumably to select the reader/format — TODO confirm in the anonfun)
     * @return all sentences from the collected files; empty if neither option was invoked
     */
    private final Seq loadSentences$1(CmdOptions.CmdOption cmdOption, CmdOptions.CmdOption cmdOption2, TransitionBasedParserArgs transitionBasedParserArgs) {
        scala.collection.immutable.Seq seq = (Seq) Seq$.MODULE$.empty();
        if (cmdOption.wasInvoked()) {
            seq = ((scala.collection.immutable.Seq) cmdOption.value()).toSeq();
        }
        if (cmdOption2.wasInvoked()) {
            seq = (Seq) seq.$plus$plus(FileUtils$.MODULE$.getFileListFromDir((String) cmdOption2.value(), FileUtils$.MODULE$.getFileListFromDir$default$2()), Seq$.MODULE$.canBuildFrom());
        }
        return (Seq) seq.flatMap(new TransitionBasedParserTrainer$$anonfun$loadSentences$1$1(transitionBasedParserArgs), Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Prints {@code str + " " + parser.testString(seq)}; does nothing when {@code seq} is empty.
     */
    private final void testSingle$1(TransitionBasedParser transitionBasedParser, Seq seq, String str) {
        if (seq.nonEmpty()) {
            Predef$.MODULE$.println(new StringBuilder().append(str).append(" ").append(transitionBasedParser.testString(seq)).toString());
        }
    }

    /** Compiler-generated default ({@code ""}) for the label parameter of {@code testSingle}. */
    private final String testSingle$default$3$1() {
        return "";
    }

    /**
     * Prints train, dev and test accuracy lines, in that order.
     *
     * @param seq train sentences
     * @param seq2 test sentences
     * @param seq3 dev sentences
     */
    private final void testAll$1(TransitionBasedParser transitionBasedParser, String str, Seq seq, Seq seq2, Seq seq3) {
        Predef$.MODULE$.println("\n");
        testSingle$1(transitionBasedParser, seq, new StringBuilder().append("Train ").append(str).toString());
        testSingle$1(transitionBasedParser, seq3, new StringBuilder().append("Dev ").append(str).toString());
        testSingle$1(transitionBasedParser, seq2, new StringBuilder().append("Test ").append(str).toString());
    }

    /** Compiler-generated default ({@code ""}) for the label parameter of {@code testAll}. */
    private final String testAll$default$2$1() {
        return "";
    }

    /**
     * Per-iteration evaluation callback. Prints the fraction of classifier weights matching
     * the anonfun predicate (presumably zero-valued weights — TODO confirm), then prints
     * train/dev/test accuracy.
     *
     * @param seq train sentences
     * @param seq2 test sentences
     * @param seq3 dev sentences
     */
    public final void cc$factorie$app$nlp$parse$TransitionBasedParserTrainer$$evaluate$2(LinearMulticlassClassifier linearMulticlassClassifier, Seq seq, Seq seq2, Seq seq3, TransitionBasedParser transitionBasedParser) {
        // Widen to double before dividing: int / int truncates the ratio to 0 (or 1).
        Predef$.MODULE$.println(new StringBuilder().append(((double) linearMulticlassClassifier.weights().mo121value().mo2059toSeq().count(new TransitionBasedParserTrainer$$anonfun$cc$factorie$app$nlp$parse$TransitionBasedParserTrainer$$evaluate$2$1())) / ((Tensor2) linearMulticlassClassifier.weights().mo121value()).length()).append(" sparsity").toString());
        testAll$1(transitionBasedParser, "iteration ", seq, seq2, seq3);
    }

    /** Publishes the singleton and runs the trait initializer. */
    private TransitionBasedParserTrainer$() {
        MODULE$ = this;
        HyperparameterMain.Cclass.$init$(this);
    }
}
