/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.sql.catalyst.expressions;

import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.TaskMemoryManager;

public final class UnsafeFixedWidthAggregationMap {
    private final byte[] emptyAggregationBuffer;
    private final StructType aggregationBufferSchema;
    private final StructType groupingKeySchema;
    private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
    private final BytesToBytesMap map;
    private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
    private byte[] groupingKeyConversionScratchSpace = new byte[8192];
    private final boolean enablePerfMetrics;

    public static boolean supportsGroupKeySchema(StructType schema) {
        for (StructField field : schema.fields()) {
            if (UnsafeRow.readableFieldTypes.contains(field.dataType())) continue;
            return false;
        }
        return true;
    }

    public static boolean supportsAggregationBufferSchema(StructType schema) {
        for (StructField field : schema.fields()) {
            if (UnsafeRow.settableFieldTypes.contains(field.dataType())) continue;
            return false;
        }
        return true;
    }

    public UnsafeFixedWidthAggregationMap(Row emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) {
        this.emptyAggregationBuffer = UnsafeFixedWidthAggregationMap.convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
        this.aggregationBufferSchema = aggregationBufferSchema;
        this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
        this.groupingKeySchema = groupingKeySchema;
        this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
        this.enablePerfMetrics = enablePerfMetrics;
    }

    private static byte[] convertToUnsafeRow(Row javaRow, StructType schema) {
        UnsafeRowConverter converter = new UnsafeRowConverter(schema);
        byte[] unsafeRow = new byte[converter.getSizeRequirement(javaRow)];
        int writtenLength = converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET);
        assert (writtenLength == unsafeRow.length) : "Size requirement calculation was wrong!";
        return unsafeRow;
    }

    public UnsafeRow getAggregationBuffer(Row groupingKey) {
        int groupingKeySize = this.groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
        if (groupingKeySize > this.groupingKeyConversionScratchSpace.length) {
            this.groupingKeyConversionScratchSpace = new byte[groupingKeySize];
        } else {
            Arrays.fill(this.groupingKeyConversionScratchSpace, 0, groupingKeySize, (byte)0);
        }
        int actualGroupingKeySize = this.groupingKeyToUnsafeRowConverter.writeRow(groupingKey, this.groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET);
        assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
        BytesToBytesMap.Location loc = this.map.lookup((Object)this.groupingKeyConversionScratchSpace, (long)PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize);
        if (!loc.isDefined()) {
            loc.putNewKey((Object)this.groupingKeyConversionScratchSpace, (long)PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize, (Object)this.emptyAggregationBuffer, (long)PlatformDependent.BYTE_ARRAY_OFFSET, this.emptyAggregationBuffer.length);
        }
        MemoryLocation address = loc.getValueAddress();
        this.currentAggregationBuffer.pointTo(address.getBaseObject(), address.getBaseOffset(), this.aggregationBufferSchema.length(), this.aggregationBufferSchema);
        return this.currentAggregationBuffer;
    }

    public Iterator<MapEntry> iterator() {
        return new Iterator<MapEntry>(){
            private final MapEntry entry = new MapEntry();
            private final Iterator<BytesToBytesMap.Location> mapLocationIterator = UnsafeFixedWidthAggregationMap.access$100(UnsafeFixedWidthAggregationMap.this).iterator();

            @Override
            public boolean hasNext() {
                return this.mapLocationIterator.hasNext();
            }

            @Override
            public MapEntry next() {
                BytesToBytesMap.Location loc = this.mapLocationIterator.next();
                MemoryLocation keyAddress = loc.getKeyAddress();
                MemoryLocation valueAddress = loc.getValueAddress();
                this.entry.key.pointTo(keyAddress.getBaseObject(), keyAddress.getBaseOffset(), UnsafeFixedWidthAggregationMap.this.groupingKeySchema.length(), UnsafeFixedWidthAggregationMap.this.groupingKeySchema);
                this.entry.value.pointTo(valueAddress.getBaseObject(), valueAddress.getBaseOffset(), UnsafeFixedWidthAggregationMap.this.aggregationBufferSchema.length(), UnsafeFixedWidthAggregationMap.this.aggregationBufferSchema);
                return this.entry;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    public void free() {
        this.map.free();
    }

    public void printPerfMetrics() {
        if (!this.enablePerfMetrics) {
            throw new IllegalStateException("Perf metrics not enabled");
        }
        System.out.println("Average probes per lookup: " + this.map.getAverageProbesPerLookup());
        System.out.println("Number of hash collisions: " + this.map.getNumHashCollisions());
        System.out.println("Time spent resizing (ns): " + this.map.getTimeSpentResizingNs());
        System.out.println("Total memory consumption (bytes): " + this.map.getTotalMemoryConsumption());
    }

    static /* synthetic */ BytesToBytesMap access$100(UnsafeFixedWidthAggregationMap x0) {
        return x0.map;
    }

    public static class MapEntry {
        public final UnsafeRow key = new UnsafeRow();
        public final UnsafeRow value = new UnsafeRow();

        private MapEntry() {
        }
    }
}

