/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.util.LinkedList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.write.DataPusher;
import org.apache.celeborn.client.write.PushTask;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.spark.TaskContext;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.SparkOutOfMemoryError;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TooLargePageException;
import org.apache.spark.shuffle.celeborn.SendBufferPool;
import org.apache.spark.shuffle.celeborn.ShuffleInMemorySorter;
import org.apache.spark.shuffle.celeborn.TaskInterruptedHelper;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link MemoryConsumer} that buffers serialized shuffle records in pages obtained from the
 * task's memory manager and pushes them to Celeborn grouped by partition.
 *
 * <p>{@link #insertRecord} copies each record onto the current page as
 * {@code [UAO length][optional 4-byte size prefix][payload]} and registers the record's encoded
 * address and (optionally randomized) partition id with a {@link ShuffleInMemorySorter}. Once the
 * consumer's usage exceeds the push threshold and the current page cannot hold the next record,
 * {@code insertRecord} returns {@code false} instead of allocating (its debug log says a push
 * needs to be triggered).
 *
 * <p>{@link #pushData} iterates the sorter's sorted records, packs consecutive records of one
 * partition into a buffer of {@code pushBufferMaxSize} bytes, hands buffers that would overflow to
 * the {@link DataPusher}, sends the remaining bytes of a partition through
 * {@code ShuffleClient#mergeData} when the partition changes, hands the very last buffer to the
 * {@link DataPusher}, and finally frees every page and the sorter's memory.
 *
 * <p>Spilling is not supported: {@link #spill} logs a warning and frees nothing.
 */
public class SortBasedPusher extends MemoryConsumer {
    private static final Logger logger = LoggerFactory.getLogger(SortBasedPusher.class);
    /** Width of the length word written in front of every record on a page. */
    private static final int UAO_SIZE = UnsafeAlignedOffset.getUaoSize();
    /**
     * Upper bound on this consumer's page size (1 << 27 bytes = 128 MiB).
     * NOTE(review): presumably mirrors the largest page a packed record pointer can address;
     * confirm against PackedRecordPointer.
     */
    private static final long MAX_PAGE_SIZE_BYTES = 1L << 27;
    /** Cap on the initial size handed to the {@link ShuffleInMemorySorter} (1 << 20). */
    private static final long MAX_INITIAL_SORTER_SIZE = 1L << 20;
    /** Bytes of the record-size prefix stored before the payload when {@code copySize} is set. */
    private static final int SIZE_PREFIX_BYTES = 4;

    private long peakMemoryUsedBytes;
    private ShuffleInMemorySorter inMemSorter;
    /** Every page allocated since the last release, in allocation order. */
    private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();
    /** Page new records are appended to; {@code null} while no page is held. */
    private MemoryBlock currentPage = null;
    /** Absolute write position inside {@link #currentPage}. */
    private long pageCursor = -1L;
    private final ShuffleClient shuffleClient;
    private final TaskContext taskContext;
    private DataPusher dataPusher;
    private final int pushBufferMaxSize;
    /** Usage (bytes) above which a full page makes {@link #insertRecord} request a push. */
    private long pushSortMemoryThreshold;
    private final int shuffleId;
    private final int mapId;
    private final int attemptNumber;
    private final int numMappers;
    private final int numPartitions;
    private final Consumer<Integer> afterPush;
    private final LongAdder[] mapStatusLengths;
    /** Real partition id -> id used for sorting; {@code null} when randomization is disabled. */
    private int[] shuffledPartitions = null;
    /** Inverse of {@link #shuffledPartitions}; {@code null} when randomization is disabled. */
    private int[] inversedShuffledPartitions = null;
    private final SendBufferPool sendBufferPool;
    final MemoryThresholdManager memoryThresholdManager;
    private final boolean useAdaptiveThreshold;
    private final double maxMemoryFactor;

    /**
     * Creates a pusher, acquires a push-task queue from {@code sendBufferPool} for its
     * {@link DataPusher} and allocates the initial sorter.
     *
     * <p>If acquiring the queue is interrupted, {@code TaskInterruptedHelper.throwTaskKillException}
     * is invoked.
     */
    public SortBasedPusher(TaskMemoryManager memoryManager, ShuffleClient shuffleClient, TaskContext taskContext, int shuffleId, int mapId, int attemptNumber, long taskAttemptId, int numMappers, int numPartitions, CelebornConf conf, Consumer<Integer> afterPush, LongAdder[] mapStatusLengths, long pushSortMemoryThreshold, SendBufferPool sendBufferPool) {
        super(memoryManager, Math.min(MAX_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()), memoryManager.getTungstenMemoryMode());
        this.shuffleClient = shuffleClient;
        this.taskContext = taskContext;
        this.shuffleId = shuffleId;
        this.mapId = mapId;
        this.attemptNumber = attemptNumber;
        this.numMappers = numMappers;
        this.numPartitions = numPartitions;
        if (conf.clientPushSortRandomizePartitionIdEnabled()) {
            this.shuffledPartitions = new int[numPartitions];
            this.inversedShuffledPartitions = new int[numPartitions];
            JavaUtils.shuffleArray(this.shuffledPartitions, this.inversedShuffledPartitions);
        }
        this.afterPush = afterPush;
        this.mapStatusLengths = mapStatusLengths;
        this.sendBufferPool = sendBufferPool;
        try {
            LinkedBlockingQueue<PushTask> pushTaskQueue = sendBufferPool.acquirePushTaskQueue();
            this.dataPusher = new DataPusher(shuffleId, mapId, attemptNumber, taskAttemptId, numMappers, numPartitions, conf, shuffleClient, pushTaskQueue, afterPush, mapStatusLengths);
        } catch (InterruptedException e) {
            TaskInterruptedHelper.throwTaskKillException();
        }
        this.pushBufferMaxSize = conf.clientPushBufferMaxSize();
        this.useAdaptiveThreshold = conf.clientPushSortUseAdaptiveMemoryThreshold();
        this.maxMemoryFactor = conf.clientPushSortMaxMemoryFactor();
        this.pushSortMemoryThreshold = pushSortMemoryThreshold;
        this.memoryThresholdManager = new MemoryThresholdManager(this.maxMemoryFactor, conf.clientPushSortSmallPushTolerateFactor());
        // Narrow only after taking the minimum: casting the long threshold to int first overflows
        // for thresholds of 2 GiB or more and yields a zero or negative sorter size.
        int initialSize = (int) Math.min(pushSortMemoryThreshold / 8L, MAX_INITIAL_SORTER_SIZE);
        this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
        this.peakMemoryUsedBytes = this.getMemoryUsage();
    }

    /**
     * Pushes every buffered record, grouped by partition, then releases all pages and the sorter
     * memory.
     *
     * @param growThreshold when {@code true}, lets {@link MemoryThresholdManager} decide whether
     *     to grow the push threshold after this push
     * @return bytes of pages freed; the same amount is added to the task's spilled-bytes metric
     * @throws IOException if merging or pushing fails, or if a stored record is larger than the
     *     push buffer (it could not be copied into the buffer without overrunning it)
     */
    public long pushData(boolean growThreshold) throws IOException {
        ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = this.inMemSorter.getSortedIterator();
        byte[] dataBuf = new byte[this.pushBufferMaxSize];
        int offSet = 0;
        int currentPartition = -1;
        while (sortedRecords.hasNext()) {
            sortedRecords.loadNext();
            int sortedPartitionId = sortedRecords.packedRecordPointer.getPartitionId();
            // Undo the partition-id randomization applied in insertRecord.
            int partition = this.shuffledPartitions != null ? this.inversedShuffledPartitions[sortedPartitionId] : sortedPartitionId;
            if (partition != currentPartition) {
                if (currentPartition != -1) {
                    // Partition switch: send what is left of the previous partition synchronously.
                    int bytesWritten = this.shuffleClient.mergeData(this.shuffleId, this.mapId, this.attemptNumber, currentPartition, dataBuf, 0, offSet, this.numMappers, this.numPartitions);
                    this.mapStatusLengths[currentPartition].add(bytesWritten);
                    this.afterPush.accept(bytesWritten);
                    this.memoryThresholdManager.updateStats(offSet, offSet == this.pushBufferMaxSize);
                    offSet = 0;
                }
                currentPartition = partition;
            }
            long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
            Object recordPage = this.taskMemoryManager.getPage(recordPointer);
            long recordOffsetInPage = this.taskMemoryManager.getOffsetInPage(recordPointer);
            int recordSize = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage);
            if (recordSize > dataBuf.length) {
                // copyMemory into dataBuf would write past the end of the array.
                throw new IOException("Record of " + recordSize + " bytes for partition " + partition + " exceeds the push buffer size of " + dataBuf.length + " bytes");
            }
            if (offSet + recordSize > dataBuf.length) {
                // Buffer cannot take this record: hand the filled part to the async pusher.
                try {
                    this.dataPusher.addTask(partition, dataBuf, offSet);
                    this.memoryThresholdManager.updateStats(offSet, true);
                } catch (InterruptedException e) {
                    TaskInterruptedHelper.throwTaskKillException();
                }
                offSet = 0;
            }
            // Skip the length word; copy only the payload.
            Platform.copyMemory(recordPage, recordOffsetInPage + UAO_SIZE, dataBuf, Platform.BYTE_ARRAY_OFFSET + offSet, recordSize);
            offSet += recordSize;
        }
        if (offSet > 0) {
            try {
                this.dataPusher.addTask(currentPartition, dataBuf, offSet);
                this.memoryThresholdManager.updateStats(offSet, offSet == this.pushBufferMaxSize);
            } catch (InterruptedException e) {
                TaskInterruptedHelper.throwTaskKillException();
            }
        }
        if (growThreshold) {
            this.memoryThresholdManager.growThresholdIfNeeded();
        }
        long freedBytes = this.freeMemory();
        this.inMemSorter.freeMemory();
        this.taskContext.taskMetrics().incMemoryBytesSpilled(freedBytes);
        return freedBytes;
    }

    /**
     * Copies one record into page memory and registers it with the sorter.
     *
     * @param recordBase base object of the source record (as understood by {@link Platform})
     * @param recordOffset offset of the record within {@code recordBase}
     * @param recordSize payload size in bytes
     * @param partitionId target partition of the record
     * @param copySize when {@code true}, a 4-byte {@code Integer.reverseBytes(recordSize)} prefix is
     *     stored before the payload and counted in the stored length
     * @return {@code false} without storing anything when usage exceeds the push threshold and the
     *     current page cannot hold the record; {@code true} once the record has been stored
     * @throws IOException propagated from a push triggered while growing the pointer array
     */
    public boolean insertRecord(Object recordBase, long recordOffset, int recordSize, int partitionId, boolean copySize) throws IOException {
        int storedSize = copySize ? recordSize + SIZE_PREFIX_BYTES : recordSize;
        int required = storedSize + UAO_SIZE;
        // Without a current page there is nothing to push, so allocation proceeds below.
        if (this.currentPage != null && this.getUsed() > this.pushSortMemoryThreshold && !this.fitsInCurrentPage(required)) {
            if (logger.isDebugEnabled()) {
                logger.debug("Memory used {} exceeds threshold {}, need to trigger push. currentPage size: {}", new Object[]{Utils.bytesToString(this.getUsed()), Utils.bytesToString(this.pushSortMemoryThreshold), Utils.bytesToString(this.currentPage.size())});
            }
            return false;
        }
        this.allocateMemoryForRecordIfNecessary(required);
        assert this.currentPage != null;
        Object base = this.currentPage.getBaseObject();
        long recordAddress = this.taskMemoryManager.encodePageNumberAndOffset(this.currentPage, this.pageCursor);
        UnsafeAlignedOffset.putSize(base, this.pageCursor, storedSize);
        this.pageCursor += UAO_SIZE;
        if (copySize) {
            Platform.putInt(base, this.pageCursor, Integer.reverseBytes(recordSize));
            this.pageCursor += SIZE_PREFIX_BYTES;
        }
        Platform.copyMemory(recordBase, recordOffset, base, this.pageCursor, recordSize);
        this.pageCursor += recordSize;
        int sortPartitionId = this.shuffledPartitions != null ? this.shuffledPartitions[partitionId] : partitionId;
        this.inMemSorter.insertRecord(recordAddress, sortPartitionId);
        return true;
    }

    /**
     * Ensures the sorter can accept one more record, growing its pointer array or pushing all
     * data when the larger array cannot be allocated.
     */
    private void growPointerArrayIfNecessary(long required) throws IOException {
        assert this.inMemSorter != null;
        if (this.inMemSorter.hasSpaceForAnotherRecord()) {
            return;
        }
        if (this.inMemSorter.numRecords() <= 0) {
            // The array was released by an earlier push; start over at the initial size.
            LongArray array = this.allocateArray(this.inMemSorter.getInitialSize());
            this.inMemSorter.expandPointerArray(array);
            return;
        }
        // NOTE(review): allocateArray() takes a number of long entries, so this is an entry count
        // (twice the current one) despite the "bytes" wording below, and the floor applied on OOM
        // compares it with a record size in bytes. Behavior kept as-is; confirm intent.
        long requestedBytes = this.inMemSorter.getMemoryUsage() / 8L * 2L;
        int maxMemoryAllocationRetry = 3;
        int allocateMemoryRetryCount = 0;
        LongArray array = null;
        boolean continueRetry = true;
        while (allocateMemoryRetryCount < maxMemoryAllocationRetry && continueRetry) {
            try {
                logger.info("asking for " + requestedBytes + " more bytes to accommodate more records");
                array = this.allocateArray(requestedBytes);
                continueRetry = false;
            } catch (TooLargePageException e) {
                logger.info("Pushdata in growPointerArrayIfNecessary, memory used {}", (Object) Utils.bytesToString(this.getUsed()));
                this.pushData(true);
                continueRetry = false;
            } catch (SparkOutOfMemoryError rethrow) {
                ++allocateMemoryRetryCount;
                if (this.inMemSorter.numRecords() <= 0) {
                    // Nothing is buffered any more, so the larger array is not needed.
                    continueRetry = false;
                } else if (allocateMemoryRetryCount == maxMemoryAllocationRetry) {
                    logger.error("OOM, unable to grow the pointer array");
                    throw rethrow;
                } else {
                    long oldReq = requestedBytes;
                    requestedBytes = Math.max((long) ((double) requestedBytes * 0.5), required);
                    logger.warn("cannot allocate " + oldReq + " bytes, cut the request to " + requestedBytes + " bytes and retry", (Throwable) rethrow);
                    // NOTE(review): pushData empties the sorter, so later attempts only allocate an
                    // array that is released again just below.
                    this.pushData(true);
                }
            }
        }
        if (this.inMemSorter.numRecords() <= 0) {
            // Everything was pushed while growing; a fresh initial-size array suffices.
            if (array != null) {
                this.freeArray(array);
            }
            array = this.allocateArray(this.inMemSorter.getInitialSize());
        }
        this.inMemSorter.expandPointerArray(array);
    }

    /** Returns true when a current page exists and has at least {@code required} free bytes. */
    private boolean fitsInCurrentPage(long required) {
        return this.currentPage != null && this.pageCursor + required <= this.currentPage.getBaseOffset() + this.currentPage.size();
    }

    /** Allocates and switches to a new page when the current one cannot hold {@code required} bytes. */
    private void acquireNewPageIfNecessary(int required) {
        if (!this.fitsInCurrentPage(required)) {
            this.currentPage = this.allocatePage(required);
            this.pageCursor = this.currentPage.getBaseOffset();
            this.allocatedPages.add(this.currentPage);
        }
    }

    /**
     * Makes room for a record of {@code required} bytes. The pointer array is checked again after
     * page acquisition because allocating a page may change memory state.
     */
    private void allocateMemoryForRecordIfNecessary(int required) throws IOException {
        this.growPointerArrayIfNecessary(required);
        this.acquireNewPageIfNecessary(required);
        this.growPointerArrayIfNecessary(required);
    }

    /** Spilling is not supported; logs a warning and reports that nothing was freed. */
    public long spill(long l, MemoryConsumer memoryConsumer) throws IOException {
        logger.warn("SortBasedPusher not support spill yet");
        return 0L;
    }

    /** Bytes held by allocated pages plus the sorter's pointer array (0 for a released sorter). */
    private long getMemoryUsage() {
        long totalPageSize = 0L;
        for (MemoryBlock page : this.allocatedPages) {
            totalPageSize += page.size();
        }
        long sorterUsage = this.inMemSorter == null ? 0L : this.inMemSorter.getMemoryUsage();
        return sorterUsage + totalPageSize;
    }

    private void updatePeakMemoryUsed() {
        this.peakMemoryUsedBytes = Math.max(this.peakMemoryUsedBytes, this.getMemoryUsage());
    }

    /** Returns the highest memory usage observed so far, including the current usage. */
    long getPeakMemoryUsedBytes() {
        this.updatePeakMemoryUsed();
        return this.peakMemoryUsedBytes;
    }

    /** Frees every allocated page and returns the number of page bytes released. */
    private long freeMemory() {
        this.updatePeakMemoryUsed();
        long memoryFreed = 0L;
        for (MemoryBlock block : this.allocatedPages) {
            memoryFreed += block.size();
            this.freePage(block);
        }
        this.allocatedPages.clear();
        this.currentPage = null;
        this.pageCursor = 0L;
        return memoryFreed;
    }

    /**
     * Releases all pages and the sorter without pushing; freed page bytes are reported as spilled.
     * The sorter reference is dropped afterwards.
     */
    public void cleanupResources() {
        long freedBytes = this.freeMemory();
        if (this.inMemSorter != null) {
            this.inMemSorter.freeMemory();
            this.inMemSorter = null;
        }
        this.taskContext.taskMetrics().incMemoryBytesSpilled(freedBytes);
    }

    public long getPushSortMemoryThreshold() {
        return this.pushSortMemoryThreshold;
    }

    /**
     * Releases memory, waits for the {@link DataPusher} to terminate and returns its idle queue
     * to the send-buffer pool.
     *
     * @param throwTaskKilledOnInterruption when {@code true}, an interruption while waiting is
     *     turned into a task-kill exception; otherwise it is swallowed (and the queue is not
     *     returned)
     */
    public void close(boolean throwTaskKilledOnInterruption) throws IOException {
        this.cleanupResources();
        try {
            this.dataPusher.waitOnTermination();
            this.sendBufferPool.returnPushTaskQueue(this.dataPusher.getAndResetIdleQueue());
        } catch (InterruptedException e) {
            // NOTE(review): when swallowed, the thread's interrupt status is not restored.
            if (throwTaskKilledOnInterruption) {
                TaskInterruptedHelper.throwTaskKillException();
            }
        }
    }

    /** Exposes the bytes currently acquired by this consumer. */
    public long getUsed() {
        return super.getUsed();
    }

    /**
     * Collects push-size statistics and doubles the enclosing pusher's threshold, up to
     * {@code Runtime.maxMemory() * maxMemoryFactor}, when the average push is noticeably smaller
     * than the average "expected" push.
     */
    class MemoryThresholdManager {
        private final long maxMemoryThresholdInBytes;
        private final double smallPushTolerateFactor;
        // Count and total bytes of every push since the last threshold change.
        long pushedCount = 0L;
        long pushedMemorySizeInBytes = 0L;
        // Count and total bytes of pushes reported with updateExpected = true
        // (in this class: pushes triggered by a full or overflowing buffer).
        long expectedPushedCount = 0L;
        long expectedPushedBytes = 0L;

        MemoryThresholdManager(double maxMemoryFactor, double smallPushTolerateFactor) {
            this.maxMemoryThresholdInBytes = (long) ((double) Runtime.getRuntime().maxMemory() * maxMemoryFactor);
            this.smallPushTolerateFactor = smallPushTolerateFactor;
        }

        /**
         * Growth requires adaptive thresholds, head-room below the maximum, and an average push
         * size that, scaled by {@code 1 + smallPushTolerateFactor}, stays below the expected size.
         */
        private boolean shouldGrow() {
            // Strict comparison: at the maximum a "grow" would be a no-op that resets the stats.
            boolean enoughSpace = SortBasedPusher.this.pushSortMemoryThreshold < this.maxMemoryThresholdInBytes;
            if (!SortBasedPusher.this.useAdaptiveThreshold || !enoughSpace || this.pushedCount == 0L) {
                return false;
            }
            double expectedPushSize = this.expectedPushedCount == 0L ? (double) Long.MAX_VALUE : (double) this.expectedPushedBytes / (double) this.expectedPushedCount;
            double averagePushSize = (double) this.pushedMemorySizeInBytes / (double) this.pushedCount;
            return averagePushSize * (1.0 + this.smallPushTolerateFactor) < expectedPushSize;
        }

        /** Doubles the threshold (capped at the maximum) and resets the statistics if warranted. */
        public void growThresholdIfNeeded() {
            if (!this.shouldGrow()) {
                return;
            }
            long oldThreshold = SortBasedPusher.this.pushSortMemoryThreshold;
            // Saturate rather than overflow when doubling a very large threshold.
            long doubled = oldThreshold > Long.MAX_VALUE / 2L ? Long.MAX_VALUE : oldThreshold * 2L;
            SortBasedPusher.this.pushSortMemoryThreshold = Math.min(doubled, this.maxMemoryThresholdInBytes);
            logger.info("grow memory threshold from {} to {}", Utils.bytesToString(oldThreshold), Utils.bytesToString(SortBasedPusher.this.pushSortMemoryThreshold));
            this.pushedCount = 0L;
            this.pushedMemorySizeInBytes = 0L;
            this.expectedPushedBytes = 0L;
            this.expectedPushedCount = 0L;
        }

        /**
         * Records one push of {@code pushedBytes}; when {@code updateExpected} is set it also
         * counts toward the expected-push statistics.
         */
        public void updateStats(long pushedBytes, boolean updateExpected) {
            this.pushedMemorySizeInBytes += pushedBytes;
            ++this.pushedCount;
            if (updateExpected) {
                this.expectedPushedBytes += pushedBytes;
                ++this.expectedPushedCount;
            }
        }
    }
}

