/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.plugin.flink.network;

import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.util.FrameDecoder;
import org.apache.celeborn.plugin.flink.network.MessageDecoderExt;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
import org.apache.celeborn.shaded.io.netty.channel.ChannelHandlerContext;
import org.apache.celeborn.shaded.io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty inbound handler that reassembles Celeborn transport frames.
 *
 * <p>Each frame starts with a 9-byte header: message length (int), message type (byte) and
 * body length (int). The encoded message follows, then the body. Bodies of messages that report
 * {@link Message#needCopyOut()} are copied into a Flink {@link ByteBuf} obtained from the
 * supplier registered for the message's stream id. All other bodies are attached to the message
 * as (possibly composite) Netty buffers.
 *
 * <p>The decoder keeps mutable per-frame state across {@link #channelRead} calls and is not
 * annotated {@code @Sharable}, so each channel needs its own instance.
 */
public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandlerAdapter implements FrameDecoder {
    public static final Logger logger = LoggerFactory.getLogger(TransportFrameDecoderWithBufferSupplier.class);

    /** Value of {@code bufferSizeBytes} that disables the large-buffer header handling. */
    public static final int DISABLE_LARGE_BUFFER_SPLIT_SIZE = -1;

    /** Frame header size: message length (int) + message type (byte) + body length (int). */
    private static final int HEADER_SIZE = 9;

    /**
     * Number of leading body bytes kept in a separate heap buffer when the body is larger than
     * {@code bufferSizeBytes}. This buffer is later prepended to the supplied Flink buffer.
     */
    private static final int LARGE_BUFFER_HEADER_SIZE = 22;

    // Length of the encoded message that follows the frame header.
    private int msgSize = -1;
    // Length of the frame body; -1 while no frame is in progress.
    private int bodySize = -1;
    private Message.Type curType = Message.Type.UNKNOWN_TYPE;
    private final ByteBuf headerBuf = Unpooled.buffer(HEADER_SIZE, HEADER_SIZE);
    // Accumulates a non-copy-out body that arrives across several reads.
    private org.apache.celeborn.shaded.io.netty.buffer.CompositeByteBuf bodyBuf = null;
    // Buffer obtained from the stream's supplier that receives a copy-out body.
    private ByteBuf externalBuf = null;
    private final ByteBuf msgBuf = Unpooled.buffer(8);
    private Message curMsg = null;
    // Body bytes still to be discarded; > 0 only while dropping a body that cannot be delivered.
    private int remainingSize = -1;
    // Body bytes (header + data) already consumed for the current copy-out frame.
    private int totalReadBytes = 0;
    private int largeBufferHeaderRemainingBytes = -1;
    private boolean isReadingLargeBuffer = false;
    private ByteBuf largeBufferHeaderBuffer;
    private final int bufferSizeBytes;
    private final ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;

    /**
     * Creates a decoder with large-buffer handling disabled.
     *
     * @param bufferSuppliers suppliers of Flink buffers, looked up by stream id
     */
    public TransportFrameDecoderWithBufferSupplier(ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers) {
        this(bufferSuppliers, DISABLE_LARGE_BUFFER_SPLIT_SIZE);
    }

    /**
     * Creates a decoder.
     *
     * @param bufferSuppliers suppliers of Flink buffers, looked up by stream id
     * @param bufferSizeBytes bodies larger than this have their first 22 bytes read into a separate
     *     buffer; {@link #DISABLE_LARGE_BUFFER_SPLIT_SIZE} disables this behaviour
     */
    public TransportFrameDecoderWithBufferSupplier(ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers, int bufferSizeBytes) {
        this.bufferSuppliers = bufferSuppliers;
        this.bufferSizeBytes = bufferSizeBytes;
    }

    /**
     * Copies as many bytes as are available from {@code source} into {@code target}, but never
     * more than needed for {@code target} to hold {@code targetSize} readable bytes.
     *
     * @return number of bytes copied (and consumed from {@code source})
     */
    private int copyByteBuf(org.apache.celeborn.shaded.io.netty.buffer.ByteBuf source, ByteBuf target, int targetSize) {
        int bytes = Math.min(source.readableBytes(), targetSize - target.readableBytes());
        // Source and target come from differently shaded Netty copies, so bridge via NIO.
        target.writeBytes(source.readSlice(bytes).nioBuffer());
        return bytes;
    }

    /** Accumulates the frame header and, once it is complete, starts decoding the message. */
    private void decodeHeader(org.apache.celeborn.shaded.io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
        copyByteBuf(buf, headerBuf, HEADER_SIZE);
        if (headerBuf.isWritable()) {
            return; // header still incomplete; wait for more bytes
        }
        msgSize = headerBuf.readInt();
        if (msgBuf.capacity() < msgSize) {
            msgBuf.capacity(msgSize);
        }
        msgBuf.clear();
        // Type.decode reads from an NIO view, which does not move headerBuf's reader index,
        // so the type byte is skipped explicitly afterwards.
        curType = Message.Type.decode(headerBuf.nioBuffer());
        headerBuf.readByte();
        bodySize = headerBuf.readInt();
        if (bufferSizeBytes != DISABLE_LARGE_BUFFER_SPLIT_SIZE && bodySize > bufferSizeBytes) {
            isReadingLargeBuffer = true;
            largeBufferHeaderBuffer = Unpooled.buffer(LARGE_BUFFER_HEADER_SIZE, LARGE_BUFFER_HEADER_SIZE);
            largeBufferHeaderRemainingBytes = LARGE_BUFFER_HEADER_SIZE;
        }
        decodeMsg(buf, ctx);
    }

    /**
     * Accumulates the encoded message and decodes it once complete. A message without a body is
     * fired immediately and the frame state is reset.
     */
    private void decodeMsg(org.apache.celeborn.shaded.io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
        if (msgBuf.readableBytes() < msgSize) {
            copyByteBuf(buf, msgBuf, msgSize);
        }
        if (msgBuf.readableBytes() == msgSize) {
            curMsg = MessageDecoderExt.decode(curType, msgBuf, false);
            if (bodySize <= 0) {
                ctx.fireChannelRead(curMsg);
                clear();
            }
        }
    }

    /**
     * Attaches the body to the current (non-copy-out) message, zero-copy where possible.
     *
     * @return {@code buf} if the caller still owns it, or {@code null} if ownership of the whole
     *     buffer moved into the accumulated body
     */
    private org.apache.celeborn.shaded.io.netty.buffer.ByteBuf decodeBody(org.apache.celeborn.shaded.io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
        if (bodyBuf == null) {
            if (buf.readableBytes() >= bodySize) {
                // Whole body is available: hand out a retained slice.
                org.apache.celeborn.shaded.io.netty.buffer.ByteBuf body = buf.retain().readSlice(bodySize);
                curMsg.setBody(body);
                ctx.fireChannelRead(curMsg);
                clear();
                return buf;
            }
            bodyBuf = buf.alloc().compositeBuffer(Integer.MAX_VALUE);
        }
        int remaining = bodySize - bodyBuf.readableBytes();
        org.apache.celeborn.shaded.io.netty.buffer.ByteBuf next;
        if (remaining >= buf.readableBytes()) {
            // The whole buffer belongs to this body; transfer ownership to bodyBuf.
            next = buf;
            buf = null;
        } else {
            next = buf.retain().readSlice(remaining);
        }
        bodyBuf.addComponent(next).writerIndex(bodyBuf.writerIndex() + next.readableBytes());
        if (bodyBuf.readableBytes() == bodySize) {
            curMsg.setBody(bodyBuf);
            ctx.fireChannelRead(curMsg);
            clear();
        }
        return buf;
    }

    /**
     * Copies the body of a copy-out message (a {@link ReadData}) into a buffer obtained from the
     * supplier registered for its stream. If there is no supplier, or the supplier fails or
     * returns {@code null}, the body bytes are dropped instead.
     *
     * @return {@code buf}; ownership always stays with the caller
     */
    private org.apache.celeborn.shaded.io.netty.buffer.ByteBuf decodeBodyCopyOut(org.apache.celeborn.shaded.io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
        if (remainingSize > 0) {
            dropUnusedBytes(buf);
            return buf;
        }
        ReadData readData = (ReadData) curMsg;
        long streamId = readData.getStreamId();
        if (externalBuf == null) {
            Supplier<ByteBuf> supplier = bufferSuppliers.get(streamId);
            if (supplier == null) {
                return needDropUnusedBytes(streamId, buf);
            }
            try {
                externalBuf = supplier.get();
            } catch (Exception e) {
                logger.warn("Failed to get buffer from supplier, streamId: {}", streamId, e);
                return needDropUnusedBytes(streamId, buf);
            }
            if (externalBuf == null) {
                // Continuing would fail with an NPE in copyByteBuf; treat it like a supplier failure.
                logger.warn("Buffer supplier returned null, streamId: {}", streamId);
                return needDropUnusedBytes(streamId, buf);
            }
        }
        if (largeBufferHeaderRemainingBytes > 0) {
            int headerReadBytes = copyByteBuf(buf, largeBufferHeaderBuffer, LARGE_BUFFER_HEADER_SIZE);
            largeBufferHeaderRemainingBytes -= headerReadBytes;
            totalReadBytes += headerReadBytes;
        } else {
            totalReadBytes += copyByteBuf(buf, externalBuf, getTargetDataBufferReadSize());
        }
        if (totalReadBytes == bodySize) {
            ByteBuf resultByteBuf;
            if (largeBufferHeaderBuffer == null) {
                resultByteBuf = externalBuf;
            } else {
                CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
                compositeByteBuf.addComponent(true, largeBufferHeaderBuffer);
                compositeByteBuf.addComponent(true, externalBuf);
                resultByteBuf = compositeByteBuf;
            }
            ((ReadData) curMsg).setFlinkBuffer(resultByteBuf);
            ctx.fireChannelRead(curMsg);
            clear();
        }
        return buf;
    }

    /** Switches into drop mode for the current body and discards what is readable now. */
    private org.apache.celeborn.shaded.io.netty.buffer.ByteBuf needDropUnusedBytes(long streamId, org.apache.celeborn.shaded.io.netty.buffer.ByteBuf byteBuf) {
        logger.warn("Need drop unused bytes, streamId: {}, bodySize: {}", (Object) streamId, (Object) bodySize);
        remainingSize = bodySize;
        dropUnusedBytes(byteBuf);
        return byteBuf;
    }

    /** Skips up to {@code remainingSize} bytes of {@code source}; resets state when done. */
    private void dropUnusedBytes(org.apache.celeborn.shaded.io.netty.buffer.ByteBuf source) {
        if (source.readableBytes() > 0) {
            if (remainingSize > source.readableBytes()) {
                remainingSize -= source.readableBytes();
                source.skipBytes(source.readableBytes());
            } else {
                source.skipBytes(remainingSize);
                clear();
            }
        }
    }

    /**
     * Feeds an incoming Netty buffer through the header/message/body state machine. The buffer
     * is released here unless its ownership was transferred into an accumulated body.
     *
     * @throws IllegalStateException if a decoded message is pending without a body to read
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object data) {
        org.apache.celeborn.shaded.io.netty.buffer.ByteBuf nettyBuf = (org.apache.celeborn.shaded.io.netty.buffer.ByteBuf) data;
        try {
            while (nettyBuf != null && nettyBuf.isReadable()) {
                if (headerBuf.isWritable()) {
                    decodeHeader(nettyBuf, ctx);
                } else if (curMsg == null) {
                    decodeMsg(nettyBuf, ctx);
                } else if (bodySize <= 0) {
                    // decodeMsg() resets state for body-less messages, so the frame state is
                    // inconsistent; looping here would spin forever without consuming bytes.
                    throw new IllegalStateException(
                        "Message " + curType + " is pending but has no body (bodySize=" + bodySize + ")");
                } else if (curMsg.needCopyOut()) {
                    nettyBuf = decodeBodyCopyOut(nettyBuf, ctx);
                } else {
                    nettyBuf = decodeBody(nettyBuf, ctx);
                }
            }
        } finally {
            if (nettyBuf != null) {
                nettyBuf.release();
            }
        }
    }

    /** Number of bytes to copy into the supplied buffer (the body minus any large-buffer header). */
    private int getTargetDataBufferReadSize() {
        return isReadingLargeBuffer ? bodySize - LARGE_BUFFER_HEADER_SIZE : bodySize;
    }

    /** Resets all per-frame state; does not release any buffers. */
    private void clear() {
        externalBuf = null;
        curMsg = null;
        curType = Message.Type.UNKNOWN_TYPE;
        headerBuf.clear();
        bodyBuf = null;
        bodySize = -1;
        remainingSize = -1;
        totalReadBytes = 0;
        largeBufferHeaderRemainingBytes = -1;
        largeBufferHeaderBuffer = null;
        isReadingLargeBuffer = false;
    }

    /**
     * Releases buffers held for a frame that is still being assembled. Completed frames hand
     * these buffers to the fired message and null the fields via clear(), so only partial-frame
     * buffers reach this point.
     */
    private void releasePendingBuffers() {
        if (bodyBuf != null) {
            // Releasing the composite also releases the retained component slices.
            bodyBuf.release();
        }
        if (externalBuf != null) {
            // Obtained from the supplier but never delivered, so nobody else will release it.
            externalBuf.release();
        }
        if (largeBufferHeaderBuffer != null) {
            largeBufferHeaderBuffer.release();
        }
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
    }

    /** Releases partially assembled frame buffers and the decoder's own header/message buffers. */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        releasePendingBuffers();
        clear();
        headerBuf.release();
        msgBuf.release();
        super.handlerRemoved(ctx);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
    }
}

