package org.apache.ignite.ml.svm;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import java.util.UUID;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.inference.json.JSONModel;
import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

/* loaded from: input_file:org/apache/ignite/ml/svm/SVMLinearClassificationModel.class */
/**
 * Linear SVM classification model.
 *
 * <p>The raw score of an input {@code x} is {@code weights . x + intercept}. When
 * {@link #isKeepingRawLabels()} is {@code true}, {@link #predict(Vector)} returns that score as-is;
 * otherwise it returns {@code 1.0} if the score is strictly greater than {@link #threshold()}
 * (default {@code 0.5}) and {@code 0.0} otherwise.
 *
 * <p>The model is mutable through its fluent {@code with*} setters.
 */
public final class SVMLinearClassificationModel implements IgniteModel<Vector, Double>, Exportable<SVMLinearClassificationModel>, JSONWritable {
    private static final long serialVersionUID = -996984622291440226L;

    /**
     * Models with fewer weights than this are rendered by {@link #toString()} as a readable linear
     * formula; larger (or weight-less) models fall back to a compact field dump.
     */
    private static final int FORMULA_WEIGHTS_LIMIT = 20;

    /** Whether {@link #predict(Vector)} returns the raw score instead of a 0/1 label. */
    private boolean isKeepingRawLabels;

    /** Decision threshold applied to the raw score when labels are not kept raw. */
    private double threshold = 0.5d;

    /** Coefficient vector; may be {@code null} until set via {@link #withWeights(Vector)}. */
    private Vector weights;

    /** Bias term added to the dot product. */
    private double intercept;

    /**
     * Jackson-friendly DTO used to export this model to JSON and to read it back.
     *
     * <p>Fields are public so Jackson can bind them directly.
     */
    public static class SVMLinearClassificationJSONExportModel extends JSONModel {
        /** Model coefficients. */
        public double[] weights;

        /** Model bias term. */
        public double intercept;

        /** Whether the restored model returns raw scores. */
        public boolean isKeepingRawLabels;

        /** Decision threshold; defaults to {@code 0.5}, matching the model's own default. */
        public double threshold = 0.5d;

        /**
         * Creates an export descriptor with the given metadata.
         *
         * @param timestamp creation timestamp forwarded to {@link JSONModel}.
         * @param uid unique identifier forwarded to {@link JSONModel}.
         * @param modelClass model class name forwarded to {@link JSONModel}.
         */
        public SVMLinearClassificationJSONExportModel(Long timestamp, String uid, String modelClass) {
            super(timestamp, uid, modelClass);
        }

        /** No-arg constructor used by Jackson during deserialization. */
        @JsonCreator
        public SVMLinearClassificationJSONExportModel() {
        }

        @Override
        public String toString() {
            return "SVMLinearClassificationJSONExportModel{weights=" + Arrays.toString(weights)
                + ", intercept=" + intercept
                + ", isKeepingRawLabels=" + isKeepingRawLabels
                + ", threshold=" + threshold + '}';
        }

        /**
         * Builds a model instance from the deserialized fields.
         *
         * @return new model carrying this descriptor's weights, intercept, raw-label flag and threshold.
         */
        @Override
        public SVMLinearClassificationModel convert() {
            return new SVMLinearClassificationModel(VectorUtils.of(weights), intercept)
                .withRawLabels(isKeepingRawLabels)
                .withThreshold(threshold);
        }
    }

    /** Creates an empty model; weights must be set before calling {@link #predict(Vector)}. */
    public SVMLinearClassificationModel() {
    }

    /**
     * Creates a model with the given coefficients.
     *
     * @param weights coefficient vector (stored by reference, not copied).
     * @param intercept bias term.
     */
    public SVMLinearClassificationModel(Vector weights, double intercept) {
        this.weights = weights;
        this.intercept = intercept;
    }

    /**
     * Sets whether {@link #predict(Vector)} returns the raw score instead of a 0/1 label.
     *
     * @param isKeepingRawLabels {@code true} to return raw scores.
     * @return this model, for chaining.
     */
    public SVMLinearClassificationModel withRawLabels(boolean isKeepingRawLabels) {
        this.isKeepingRawLabels = isKeepingRawLabels;
        return this;
    }

    /**
     * Sets the decision threshold used when labels are not kept raw.
     *
     * @param threshold scores strictly above this value are labelled {@code 1.0}.
     * @return this model, for chaining.
     */
    public SVMLinearClassificationModel withThreshold(double threshold) {
        this.threshold = threshold;
        return this;
    }

    /**
     * Sets the coefficient vector.
     *
     * @param weights coefficient vector (stored by reference, not copied).
     * @return this model, for chaining.
     */
    public SVMLinearClassificationModel withWeights(Vector weights) {
        this.weights = weights;
        return this;
    }

    /**
     * Sets the bias term.
     *
     * @param intercept bias term.
     * @return this model, for chaining.
     */
    public SVMLinearClassificationModel withIntercept(double intercept) {
        this.intercept = intercept;
        return this;
    }

    /**
     * Scores {@code input} with the linear model.
     *
     * @param input feature vector; must be dot-compatible with the weights.
     * @return the raw score {@code weights . input + intercept} when raw labels are kept,
     *     otherwise {@code 1.0} if that score is strictly above the threshold and {@code 0.0} if not.
     */
    @Override
    public Double predict(Vector input) {
        double score = input.dot(weights) + intercept;

        if (isKeepingRawLabels) {
            return score;
        }

        return score > threshold ? 1.0d : 0.0d;
    }

    /** @return whether {@link #predict(Vector)} returns raw scores. */
    public boolean isKeepingRawLabels() {
        return isKeepingRawLabels;
    }

    /** @return the decision threshold. */
    public double threshold() {
        return threshold;
    }

    /** @return the coefficient vector (the internal reference, not a copy); may be {@code null}. */
    public Vector weights() {
        return weights;
    }

    /** @return the bias term. */
    public double intercept() {
        return intercept;
    }

    @Override
    public <P> void saveModel(Exporter<SVMLinearClassificationModel, P> exporter, P path) {
        exporter.save(this, path);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }

        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }

        SVMLinearClassificationModel other = (SVMLinearClassificationModel) obj;

        return Double.compare(other.intercept, intercept) == 0
            && Double.compare(other.threshold, threshold) == 0
            && other.isKeepingRawLabels == isKeepingRawLabels
            && Objects.equals(weights, other.weights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(weights, intercept, isKeepingRawLabels, threshold);
    }

    /**
     * Renders the model.
     *
     * <p>Small models are printed as a linear formula, e.g.
     * {@code -1.0000*x0 + 2.0000*x1 - 3.0000}. Large models, or models without weights, are printed
     * as {@code SVMModel [weights=..., intercept=...]}.
     */
    @Override
    public String toString() {
        if (weights == null || weights.size() >= FORMULA_WEIGHTS_LIMIT) {
            return "SVMModel [weights=" + weights + ", intercept=" + intercept + ']';
        }

        int size = weights.size();
        StringBuilder sb = new StringBuilder();

        // Every term is printed as an absolute value, and its sign is emitted as the separator
        // before it. The first term has no separator, so its sign must be written explicitly.
        double leading = size == 0 ? intercept : weights.get(0);
        if (leading < 0.0d) {
            sb.append('-');
        }

        for (int i = 0; i < size; i++) {
            double next = i == size - 1 ? intercept : weights.get(i + 1);

            sb.append(formatAbs(weights.get(i)))
                .append("*x")
                .append(i)
                .append(next < 0.0d ? " - " : " + ");
        }

        sb.append(formatAbs(intercept));

        return sb.toString();
    }

    /**
     * Formats the magnitude of {@code val} with four decimals, independent of the default locale.
     *
     * @param val value whose absolute value is formatted.
     * @return formatted magnitude, always using {@code '.'} as the decimal separator.
     */
    private static String formatAbs(double val) {
        return String.format(Locale.ROOT, "%.4f", Math.abs(val));
    }

    @Override
    public String toString(boolean pretty) {
        return toString();
    }

    /**
     * Reads a model previously written by {@link #toJSON(Path)}.
     *
     * @param path JSON file location.
     * @return restored model, or {@code null} if the file could not be read or parsed (the
     *     {@link IOException} stack trace is printed).
     */
    public static SVMLinearClassificationModel fromJSON(Path path) {
        try {
            SVMLinearClassificationJSONExportModel exportModel = new ObjectMapper()
                .readValue(new File(path.toAbsolutePath().toString()), SVMLinearClassificationJSONExportModel.class);

            return exportModel.convert();
        }
        catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Writes this model to {@code path} as JSON.
     *
     * <p>On {@link IOException} the stack trace is printed and the failure is not propagated.
     *
     * @param path destination file location.
     */
    @Override
    public void toJSON(Path path) {
        ObjectMapper mapper = new ObjectMapper();

        try {
            SVMLinearClassificationJSONExportModel exportModel = new SVMLinearClassificationJSONExportModel(
                System.currentTimeMillis(),
                "svm_" + UUID.randomUUID().toString(),
                SVMLinearClassificationModel.class.getSimpleName()
            );

            exportModel.intercept = intercept;
            exportModel.isKeepingRawLabels = isKeepingRawLabels;
            exportModel.threshold = threshold;
            exportModel.weights = weights.asArray();

            mapper.writeValue(new File(path.toAbsolutePath().toString()), exportModel);
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }
}
