package org.apache.ignite.ml.tree.randomforest.data.statistics;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.tree.randomforest.data.NodeId;
import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;

/* loaded from: input_file:org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.class */
/**
 * Base class for assigning values to the leaves of random forest trees.
 *
 * <p>Every vector of a bootstrapped dataset is routed through each tree to a leaf. The vector is
 * accumulated into a per-leaf statistics aggregator, partitions' aggregators are merged, and the
 * final value of each leaf is computed from its merged statistics.
 *
 * <p>The functions passed to {@link Dataset#compute} capture this instance, so subclasses should
 * keep their state serializable.
 *
 * @param <T> Type of the per-leaf statistics aggregator.
 */
public abstract class LeafValuesComputer<T> implements Serializable {
    private static final long serialVersionUID = -429848953091775832L;

    /**
     * Computes a value for every leaf of the given trees and stores it via {@code TreeNode#setVal}.
     * Leaves that no vector reached keep their current value.
     *
     * @param roots Trees whose leaves should receive values; a tree's position in this list is
     *     passed to the aggregator hooks as {@code sampleId}.
     * @param dataset Bootstrapped dataset used to collect the leaf statistics.
     */
    public void setValuesForLeaves(ArrayList<RandomForestTreeModel> roots,
        Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
        Map<NodeId, TreeNode> leafs = roots.stream()
            .flatMap(tree -> tree.leafs().stream())
            .collect(Collectors.toMap(TreeNode::getId, Function.identity()));

        Map<NodeId, T> stats = dataset.compute(
            data -> computeLeafsStatisticsInPartition(roots, leafs, data),
            this::mergeLeafStatistics
        );

        // mergeLeafStatistics treats null as "no statistics"; the reduction may therefore yield
        // null (presumably when there are no partitions to reduce — TODO confirm), in which case
        // there is nothing to assign.
        if (stats == null) {
            return;
        }

        leafs.forEach((leafId, leaf) -> {
            T leafStats = stats.get(leafId);
            if (leafStats != null) {
                leaf.setVal(computeLeafValue(leafStats));
            }
        });
    }

    /**
     * Aggregates the statistics of one dataset partition for all leaves of all trees.
     *
     * @param roots Trees to route vectors through.
     * @param leafs All known leaves of {@code roots}, keyed by node id.
     * @param data Partition to scan; it is traversed once per tree.
     * @return Map from leaf id to the statistics accumulated in this partition.
     * @throws IllegalStateException If a vector is routed to a node that is not in {@code leafs}.
     */
    private Map<NodeId, T> computeLeafsStatisticsInPartition(ArrayList<RandomForestTreeModel> roots,
        Map<NodeId, TreeNode> leafs, BootstrappedDatasetPartition data) {
        Map<NodeId, T> res = new HashMap<>();

        for (int treeIdx = 0; treeIdx < roots.size(); treeIdx++) {
            // Effectively-final copies for the lambda below; the tree lookup is hoisted out of
            // the per-vector loop.
            final int sampleId = treeIdx;
            final RandomForestTreeModel tree = roots.get(treeIdx);

            data.forEach(vec -> {
                NodeId leafId = tree.getRootNode().predictNextNodeKey(vec.features());
                if (!leafs.containsKey(leafId)) {
                    throw new IllegalStateException(
                        "Vector was routed to node " + leafId + " which is not a known leaf");
                }

                // Explicit containsKey check (rather than computeIfAbsent) keeps the aggregator
                // stored even if a subclass returns null from createLeafStatsAggregator.
                if (!res.containsKey(leafId)) {
                    res.put(leafId, createLeafStatsAggregator(sampleId));
                }

                addElementToLeafStatistic(res.get(leafId), vec, sampleId);
            });
        }

        return res;
    }

    /**
     * Merges the per-leaf statistics of two partitions. {@code null} is treated as "no statistics".
     * The merge is performed in place into {@code left} when both maps are present.
     *
     * @param left Statistics of the first partition (may be {@code null}).
     * @param right Statistics of the second partition (may be {@code null}).
     * @return Merged statistics.
     */
    private Map<NodeId, T> mergeLeafStatistics(Map<NodeId, T> left, Map<NodeId, T> right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }

        // Leaves present only in `left` need no work, so iterating `right` covers every case.
        for (Map.Entry<NodeId, T> entry : right.entrySet()) {
            NodeId leafId = entry.getKey();
            if (left.containsKey(leafId)) {
                left.put(leafId, mergeLeafStats(left.get(leafId), entry.getValue()));
            } else {
                left.put(leafId, entry.getValue());
            }
        }

        return left;
    }

    /**
     * Adds one vector to a leaf's statistics aggregator.
     *
     * @param leafStatAggr Aggregator of the leaf the vector was routed to.
     * @param leafVal Vector routed to the leaf.
     * @param sampleId Position of the tree in the list passed to {@link #setValuesForLeaves}.
     */
    protected abstract void addElementToLeafStatistic(T leafStatAggr, BootstrappedVector leafVal, int sampleId);

    /**
     * Merges two aggregators of the same leaf that were computed on different partitions.
     *
     * @param leftStats First aggregator.
     * @param rightStats Second aggregator.
     * @return Merged aggregator; it replaces the leaf's entry in the merged map.
     */
    protected abstract T mergeLeafStats(T leftStats, T rightStats);

    /**
     * Creates the aggregator for a leaf; called the first time a vector reaches that leaf within a
     * partition.
     *
     * @param sampleId Position of the tree in the list passed to {@link #setValuesForLeaves}.
     * @return New aggregator.
     */
    protected abstract T createLeafStatsAggregator(int sampleId);

    /**
     * Computes the final value of a leaf from its merged statistics.
     *
     * @param stat Merged statistics of the leaf.
     * @return Value stored into the leaf.
     */
    protected abstract double computeLeafValue(T stat);
}
