package org.apache.mahout.flinkbindings.blas;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.api.scala.DataSet;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.shaded.com.google.common.collect.Lists;
import org.apache.flink.util.Collector;
import org.apache.mahout.flinkbindings.FlinkDistributedContext;
import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm;
import org.apache.mahout.flinkbindings.drm.FlinkDrm;
import org.apache.mahout.math.DenseSymmetricMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.drm.logical.OpAtA;
import org.apache.mahout.math.scalabindings.RLikeOps$;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Predef$Ensuring$;
import scala.Tuple2;
import scala.collection.JavaConverters$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Range;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: FlinkOpAtA.scala */
/* loaded from: input_file:org/apache/mahout/flinkbindings/blas/FlinkOpAtA$.class */
/**
 * Flink physical implementation of the {@code A' * A} operator ({@link OpAtA}), in the singleton
 * form a Scala {@code object} compiles to (this file originates from {@code FlinkOpAtA.scala}).
 *
 * <p>{@link #at_a} selects one of two strategies by the operand's column count: {@link #slim}
 * produces the product in-core as a {@link DenseSymmetricMatrix}, while {@link #fat} keeps it
 * distributed as a {@link BlockifiedFlinkDrm}.
 */
public final class FlinkOpAtA$ {
    /** The singleton instance. */
    public static final FlinkOpAtA$ MODULE$ = new FlinkOpAtA$();

    /** Backing constants for the accessors of the same name. */
    private final String PROPERTY_ATA_MAXINMEMNCOL = "mahout.math.AtA.maxInMemNCol";
    private final String PROPERTY_ATA_MAXINMEMNCOL_DEFAULT = "200";

    /** Returns the name of the system property holding the in-memory column threshold. */
    public final String PROPERTY_ATA_MAXINMEMNCOL() {
        return PROPERTY_ATA_MAXINMEMNCOL;
    }

    /** Returns the default (as a string) used when the threshold property is not set. */
    public final String PROPERTY_ATA_MAXINMEMNCOL_DEFAULT() {
        return PROPERTY_ATA_MAXINMEMNCOL_DEFAULT;
    }

    /**
     * Computes {@code A' * A} for the operand of {@code opAtA}.
     *
     * <p>The column threshold is read from the system property named by
     * {@link #PROPERTY_ATA_MAXINMEMNCOL()}, falling back to
     * {@link #PROPERTY_ATA_MAXINMEMNCOL_DEFAULT()}. When {@code opAtA.ncol()} exceeds it the
     * distributed {@link #fat} strategy is used. Otherwise the product is computed in-core by
     * {@link #slim} and re-parallelized into a DRM with a single partition.
     *
     * @param opAtA    the logical {@code A' * A} operator
     * @param flinkDrm the physical operand {@code A}
     * @return the product as a Flink DRM keyed by {@code Int}
     * @throws NumberFormatException presumably, if the property is not an integer (raised by
     *                               Scala's {@code StringOps.toInt}) -- TODO confirm
     */
    public <K> FlinkDrm<Object> at_a(OpAtA<K> opAtA, FlinkDrm<K> flinkDrm) {
        String maxInMemStr = System.getProperty(PROPERTY_ATA_MAXINMEMNCOL(), PROPERTY_ATA_MAXINMEMNCOL_DEFAULT());
        int maxInMemNCol = new StringOps(Predef$.MODULE$.augmentString(maxInMemStr)).toInt();
        // The predicate and message live in synthetic closure classes not visible here;
        // presumably they reject a non-positive threshold -- TODO confirm.
        Predef$Ensuring$.MODULE$.ensuring$extension3(
                Predef$.MODULE$.any2Ensuring(BoxesRunTime.boxToInteger(maxInMemNCol)),
                new FlinkOpAtA$$anonfun$at_a$1(),
                new FlinkOpAtA$$anonfun$at_a$2());
        // Result intentionally discarded; kept to preserve the original call sequence.
        flinkDrm.classTag();
        if (opAtA.ncol() > maxInMemNCol) {
            return fat(opAtA, flinkDrm);
        }
        FlinkDistributedContext context = flinkDrm.context();
        Matrix inCoreAtA = slim(opAtA, flinkDrm);
        return org.apache.mahout.flinkbindings.package$.MODULE$.checkpointedDrmToFlinkDrm(
                org.apache.mahout.math.drm.package$.MODULE$.drmParallelize(inCoreAtA, 1, context),
                BasicTypeInfo.getInfoFor(Integer.TYPE),
                ClassTag$.MODULE$.Int());
    }

    /**
     * Computes {@code A' * A} in-core.
     *
     * <p>Each partition of the row-wise DRM is mapped to a {@link Vector} by a synthetic closure
     * parameterized with {@code ncol}. The per-partition vectors are combined by another synthetic
     * reduce closure. The single collected vector is then wrapped as a
     * {@link DenseSymmetricMatrix}; presumably that vector is a packed triangular representation
     * of the product -- TODO confirm.
     */
    public <K> Matrix slim(OpAtA<K> opAtA, FlinkDrm<K> flinkDrm) {
        return new DenseSymmetricMatrix((Vector) flinkDrm.asRowWise().ds()
                .mapPartition(new FlinkOpAtA$$anonfun$3(opAtA.ncol()),
                        TypeExtractor.createTypeInfo(Vector.class),
                        ClassTag$.MODULE$.apply(Vector.class))
                .reduce(new FlinkOpAtA$$anonfun$4())
                .collect()
                .head());
    }

    /**
     * Computes {@code A' * A} as a distributed, blockified DRM (the "fat" strategy).
     *
     * <ol>
     *   <li>The blocks of {@code A} are counted. The count is broadcast as
     *       {@code "numberOfPartitions"} to both stages below.</li>
     *   <li>Each task derives the same column ranges via {@link #estimatePartitions} and
     *       {@link #computeEvenSplits}. For every block, a synthetic closure iterates
     *       {@code (range, index)} pairs and emits {@code (index, Matrix)} tuples.</li>
     *   <li>Tuples are grouped by range index. Each group's matrices are mapped and reduced into
     *       one matrix by synthetic closures. That matrix is emitted with row keys tabulated from
     *       the range start (presumably {@code start + row} -- TODO confirm).</li>
     * </ol>
     *
     * @return a {@link BlockifiedFlinkDrm} constructed with {@code ncol} and {@code Int} keys
     */
    public <K> FlinkDrm<Object> fat(OpAtA<K> opAtA, FlinkDrm<K> flinkDrm) {
        final long nrow = opAtA.A().nrow();
        final int ncol = opAtA.A().ncol();
        DataSet<Tuple2<Object, Matrix>> ds = flinkDrm.asBlockified().ds();

        // Number of blocks in A: map every block to 1 and sum.
        DataSet numberOfPartitions = ds.map(new MapFunction<Tuple2<Object, Matrix>, Object>() {
            public Object map(Tuple2<Object, Matrix> block) {
                return BoxesRunTime.boxToInteger(1);
            }
        }, BasicTypeInfo.getInfoFor(Integer.TYPE), ClassTag$.MODULE$.Int()).reduce(new ReduceFunction<Object>() {
            public Object reduce(Object left, Object right) {
                return BoxesRunTime.boxToInteger(BoxesRunTime.unboxToInt(left) + BoxesRunTime.unboxToInt(right));
            }
        });

        return new BlockifiedFlinkDrm(ds.flatMap(new RichFlatMapFunction<Tuple2<Object, Matrix>, Tuple2<Object, Matrix>>() {
            // Column ranges used for every block handled by this task; assigned in open().
            private Range[] ranges = null;

            public Range[] ranges() {
                return this.ranges;
            }

            public void ranges_$eq(Range[] rangeArr) {
                this.ranges = rangeArr;
            }

            public void open(Configuration configuration) {
                // NOTE(review): estimatePartitions caps the split count by nrow, but the splits
                // partition the ncol columns and computeEvenSplits requires splits <= ncol, so
                // more blocks than columns would fail here -- TODO confirm intended cap.
                int parts = BoxesRunTime.unboxToInt(getRuntimeContext().getBroadcastVariable("numberOfPartitions").get(0));
                ranges_$eq(FlinkOpAtA$.MODULE$.computeEvenSplits(ncol, FlinkOpAtA$.MODULE$.estimatePartitions(nrow, ncol, parts)));
            }

            public void flatMap(Tuple2<Object, Matrix> tuple2, Collector<Tuple2<Object, Matrix>> collector) {
                Object[] indexedRanges = (Object[]) Predef$.MODULE$.refArrayOps(ranges())
                        .zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
                Predef$.MODULE$.refArrayOps(indexedRanges)
                        .foreach(new FlinkOpAtA$$anon$1$$anonfun$flatMap$1(this, collector, (Matrix) tuple2._2()));
            }
        }, new FlinkOpAtA$$anon$5(), ClassTag$.MODULE$.apply(Tuple2.class))
                .withBroadcastSet(numberOfPartitions, "numberOfPartitions")
                .groupBy(Predef$.MODULE$.wrapIntArray(new int[]{0}))
                .reduceGroup(new RichGroupReduceFunction<Tuple2<Object, Matrix>, Tuple2<int[], Matrix>>() {
            // Same column ranges as the flat-map stage, recomputed from the broadcast count.
            private Range[] ranges = null;

            public Range[] ranges() {
                return this.ranges;
            }

            public void ranges_$eq(Range[] rangeArr) {
                this.ranges = rangeArr;
            }

            public void open(Configuration configuration) {
                int parts = BoxesRunTime.unboxToInt(getRuntimeContext().getBroadcastVariable("numberOfPartitions").get(0));
                ranges_$eq(FlinkOpAtA$.MODULE$.computeEvenSplits(ncol, FlinkOpAtA$.MODULE$.estimatePartitions(nrow, ncol, parts)));
            }

            public void reduce(Iterable<Tuple2<Object, Matrix>> iterable, Collector<Tuple2<int[], Matrix>> collector) {
                Buffer group = (Buffer) JavaConverters$.MODULE$.asScalaBufferConverter(Lists.newArrayList(iterable)).asScala();
                Tuple2 first = (Tuple2) group.head();
                // A null head fails the original pattern match.
                if (first == null) {
                    throw new MatchError(first);
                }
                int rangeIndex = first._1$mcI$sp();
                Matrix combined = (Matrix) ((TraversableOnce) group.map(
                        new FlinkOpAtA$$anon$3$$anonfun$5(this), Buffer$.MODULE$.canBuildFrom()))
                        .reduce(new FlinkOpAtA$$anon$3$$anonfun$6(this));
                int rowCount = RLikeOps$.MODULE$.m2mOps(combined).nrow();
                int[] rowKeys = (int[]) Array$.MODULE$.tabulate(rowCount,
                        new FlinkOpAtA$$anon$3$$anonfun$1(this, ranges()[rangeIndex].start()),
                        ClassTag$.MODULE$.Int());
                collector.collect(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.any2ArrowAssoc(rowKeys), combined));
            }
        }, new FlinkOpAtA$$anon$6(), ClassTag$.MODULE$.apply(Tuple2.class))
                .withBroadcastSet(numberOfPartitions, "numberOfPartitions"),
                ncol, BasicTypeInfo.getInfoFor(Integer.TYPE), ClassTag$.MODULE$.Int());
    }

    /**
     * Estimates how many column splits the fat product should use.
     *
     * <p>Computes {@code round(nrow * ncol / (nrow * ncol / parts))}. The result is clamped to at
     * least 1 and then capped at {@code nrow}. The arithmetic is done in {@code double}. In
     * {@code long} arithmetic the inner quotient truncates, and it becomes 0 when
     * {@code nrow * ncol < parts}, which made the outer division throw
     * {@link ArithmeticException}.
     *
     * @param j  number of rows of {@code A}
     * @param i  number of columns of {@code A}
     * @param i2 number of input partitions (blocks)
     * @return the estimated split count: at least 1 unless capped by {@code j}, and never above
     *         {@code j}
     */
    public int estimatePartitions(long j, int i, int i2) {
        double elements = (double) j * i;
        double elementsPerPartition = elements / i2;
        int nparts = Math.max((int) Math.round(elements / elementsPerPartition), 1);
        return ((long) nparts) > j ? (int) j : nparts;
    }

    /**
     * Splits {@code [0, j)} into {@code i} ranges.
     *
     * <p>{@code i + 1} offsets are produced for {@code 0 to i} (Scala's inclusive range) by a
     * synthetic closure. That closure is parameterized with the base size {@code j / i} and the
     * remainder {@code j % i}. Consecutive offset pairs (a sliding window of 2) are turned into
     * {@link Range}s by another synthetic closure, presumably as
     * {@code offsets(k) until offsets(k + 1)} -- TODO confirm.
     *
     * @param j total length to split; must be at least 1
     * @param i number of splits; must be in {@code [1, j]}
     * @return {@code i} ranges
     * @throws IllegalArgumentException if a precondition fails (Scala {@code require})
     */
    public Range[] computeEvenSplits(long j, int i) {
        Predef$.MODULE$.require(((long) i) <= j, new FlinkOpAtA$$anonfun$computeEvenSplits$1());
        Predef$.MODULE$.require(j >= 1);
        Predef$.MODULE$.require(i >= 1);
        IndexedSeq offsets = (IndexedSeq) RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(0), i).map(
                new FlinkOpAtA$$anonfun$2(
                        org.apache.mahout.math.drm.package$.MODULE$.safeToNonNegInt(j / i),
                        org.apache.mahout.math.drm.package$.MODULE$.safeToNonNegInt(j % i)),
                IndexedSeq$.MODULE$.canBuildFrom());
        return (Range[]) offsets.sliding(2).map(new FlinkOpAtA$$anonfun$7()).toArray(ClassTag$.MODULE$.apply(Range.class));
    }

    /** Private: the only instance is {@link #MODULE$}. */
    private FlinkOpAtA$() {
    }
}
