package jsat.classifiers.bayesian;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;

/* loaded from: input_file:jsat/classifiers/bayesian/AODE.class */
/**
 * Averaged One-Dependence Estimators (AODE) classifier over categorical features.
 *
 * <p>{@link #setUp} builds one {@link ODE} per categorical attribute, constructing the i-th ODE
 * with attribute index {@code i}. At prediction time, for each class, the exponentiated
 * log-probabilities of every ODE whose {@code priors} entry for (class, value of its
 * {@code dependent} attribute) reaches the minimum count {@link #getM() m} are summed. The
 * per-class sums are then normalized.
 *
 * <p>Only the categorical values of a data point are read by {@link #classify}. The numeric
 * attribute count is merely forwarded to each ODE's {@code setUp}.
 */
public class AODE extends BaseUpdateableClassifier {
    private static final long serialVersionUID = 8386506277969540732L;

    /** Default minimum count threshold used by {@link #classify}. */
    private static final double DEFAULT_M = 20.0d;

    /** The target variable being predicted; assigned by {@link #setUp}. */
    protected CategoricalData predicting;
    /** One ODE per categorical attribute; {@code odes[i]} was constructed with index {@code i}. */
    protected ODE[] odes;
    /** Minimum count an ODE's prior entry must reach for that ODE to contribute in {@link #classify}. */
    private double m = DEFAULT_M;

    /** Creates an untrained AODE with the default minimum count of 20. */
    public AODE() {
    }

    /**
     * Copy constructor: deep-copies the ODEs and the predicting variable (when present) and
     * copies the minimum count threshold.
     *
     * @param aode the instance to copy
     */
    protected AODE(AODE aode) {
        if (aode.odes != null) {
            this.odes = new ODE[aode.odes.length];
            for (int i = 0; i < this.odes.length; i++) {
                this.odes[i] = aode.odes[i].mo0clone();
            }
        }
        if (aode.predicting != null) {
            this.predicting = aode.predicting.m1clone();
        }
        this.m = aode.m;
    }

    /**
     * Returns a deep copy of this classifier (decompiler-renamed {@code clone}).
     *
     * @return a new AODE equal in state to this one
     */
    @Override // jsat.classifiers.BaseUpdateableClassifier
    /* renamed from: clone */
    public AODE mo0clone() {
        return new AODE(this);
    }

    /**
     * Prepares the classifier for incremental updates by creating one ODE per categorical
     * attribute.
     *
     * @param categoricalAttributes the categorical attributes of the input data
     * @param numNumericalVars the number of numeric attributes (forwarded to each ODE)
     * @param predictingVar the target variable
     * @throws FailedToFitException if there are no categorical attributes
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalAttributes, int numNumericalVars, CategoricalData predictingVar) {
        if (categoricalAttributes.length < 1) {
            throw new FailedToFitException("At least 1 categorical variable is needed for AODE");
        }
        this.predicting = predictingVar;
        this.odes = new ODE[categoricalAttributes.length];
        for (int i = 0; i < this.odes.length; i++) {
            this.odes[i] = new ODE(i);
            this.odes[i].setUp(categoricalAttributes, numNumericalVars, predictingVar);
        }
    }

    /**
     * Trains on the whole data set, updating each ODE in its own task on the given executor.
     * Each ODE is touched by exactly one task.
     *
     * <p>If the waiting thread is interrupted, training falls back to the single-argument
     * {@code trainC}, and the thread's interrupt status is restored afterwards.
     *
     * @param classificationDataSet the training data
     * @param executorService executor used to run one task per ODE
     */
    @Override // jsat.classifiers.BaseUpdateableClassifier, jsat.classifiers.Classifier
    public void trainC(final ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        setUp(classificationDataSet.getCategories(), classificationDataSet.getNumNumericalVars(), classificationDataSet.getPredicting());
        final int sampleSize = classificationDataSet.getSampleSize();
        final CountDownLatch latch = new CountDownLatch(this.odes.length);
        for (final ODE ode : this.odes) {
            executorService.submit(new Runnable() {
                @Override // java.lang.Runnable
                public void run() {
                    try {
                        for (int j = 0; j < sampleSize; j++) {
                            ode.update(classificationDataSet.getDataPoint(j), classificationDataSet.getDataPointCategory(j));
                        }
                    } finally {
                        // Release the latch even if an update throws; otherwise await() below
                        // would block forever. The exception itself is only captured by the
                        // (discarded) Future returned from submit().
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            // NOTE(review): presumably the base-class trainC re-runs setUp and trains serially;
            // still-running tasks then only mutate the replaced ODE instances.
            try {
                trainC(classificationDataSet);
            } finally {
                // Preserve the interrupt so callers up the stack can observe it.
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Incrementally updates every ODE with one labeled data point.
     *
     * @param dataPoint the data point
     * @param targetClass the class index of the data point
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int targetClass) {
        for (ODE ode : this.odes) {
            ode.update(dataPoint, targetClass);
        }
    }

    /**
     * Computes normalized class probabilities for a data point.
     *
     * <p>NOTE(review): if no ODE passes the minimum-count threshold for any class, every
     * per-class sum is 0 before {@code normalize()}. How that case is handled depends on
     * {@code CategoricalResults.normalize}, which is not visible here.
     *
     * @param dataPoint the point to classify; only its categorical values are read
     * @return per-class probabilities
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults results = new CategoricalResults(this.predicting.getNumOfCategories());
        int[] categoricalValues = dataPoint.getCategoricalValues();
        for (int c = 0; c < results.size(); c++) {
            double score = 0.0d;
            for (ODE ode : this.odes) {
                // Only ODEs whose (class, parent value) prior reaches the minimum count contribute.
                if (ode.priors[c][categoricalValues[ode.dependent]] >= this.m) {
                    score += Math.exp(ode.getLogPrb(categoricalValues, c));
                }
            }
            results.setProb(c, score);
        }
        results.normalize();
        return results;
    }

    /**
     * Reports that this classifier accepts weighted data.
     *
     * @return {@code true}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Sets the minimum count an ODE's prior entry must reach to contribute to a prediction.
     *
     * @param d the new minimum count; must be finite and non-negative
     * @throws ArithmeticException if {@code d} is negative, infinite, or NaN
     */
    public void setM(double d) {
        if (d < 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new ArithmeticException("The minimum count must be a non negative number");
        }
        this.m = d;
    }

    /**
     * Returns the minimum count threshold.
     *
     * @return the minimum count used by {@link #classify}
     */
    public double getM() {
        return this.m;
    }
}
