/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.plugin.flink.network;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbTransportableError;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client-side message handler for the Flink plugin's read path.
 *
 * <p>Incoming stream messages are dispatched by stream id to the {@link Consumer} registered via
 * {@link #registerHandler}. Messages for stream ids with no registered consumer are dropped; for
 * {@link ReadData} the attached Flink buffer is released first so it is not leaked. When a
 * channel becomes inactive, every stream registered against that exact client instance receives
 * a synthetic {@link TransportableError}.
 *
 * <p>The two registries are {@link ConcurrentHashMap}s, so each is safe for concurrent access;
 * note that register/remove update the two maps separately, not atomically as a pair.
 */
public class ReadClientHandler
extends BaseMessageHandler {
    private static final Logger logger = LoggerFactory.getLogger(ReadClientHandler.class);

    // Numeric message-type ids found inside RPC_REQUEST payloads. The decompiled source had these
    // inlined; presumably they are the protobuf MessageType *_VALUE constants for
    // BACKLOG_ANNOUNCEMENT / BUFFER_STREAM_END / TRANSPORTABLE_ERROR -- verify against the
    // Celeborn protocol definition before changing them.
    private static final int BACKLOG_ANNOUNCEMENT_TYPE = 59;
    private static final int BUFFER_STREAM_END_TYPE = 60;
    private static final int TRANSPORTABLE_ERROR_TYPE = 64;

    /** Per-stream message consumers, keyed by stream id. */
    private final ConcurrentHashMap<Long, Consumer<RequestMessage>> streamHandlers = JavaUtils.newConcurrentHashMap();
    /** The client each stream was registered with, keyed by stream id. */
    private final ConcurrentHashMap<Long, TransportClient> streamClients = JavaUtils.newConcurrentHashMap();

    /**
     * Registers {@code handle} as the consumer for messages of {@code streamId} and remembers the
     * client the stream belongs to, so the stream can be notified if that client's channel goes
     * inactive. An existing registration for the same id is replaced.
     *
     * @param streamId the stream to register
     * @param handle consumer invoked for every message of this stream
     * @param client the client the stream was opened on
     */
    public void registerHandler(long streamId, Consumer<RequestMessage> handle, TransportClient client) {
        this.streamHandlers.put(streamId, handle);
        this.streamClients.put(streamId, client);
    }

    /**
     * Removes the consumer and client registered for {@code streamId}; a no-op if none exist.
     *
     * @param streamId the stream to unregister
     */
    public void removeHandler(long streamId) {
        this.streamHandlers.remove(streamId);
        this.streamClients.remove(streamId);
    }

    /**
     * Delivers {@code msg} to the consumer registered for {@code streamId}, or drops it if none.
     * Dropped {@link ReadData} messages have their Flink buffer released; every dropped message
     * except {@link BufferStreamEnd} is logged as unexpected.
     */
    private void processMessageInternal(long streamId, RequestMessage msg) {
        Consumer<RequestMessage> handler = this.streamHandlers.get(streamId);
        if (handler != null) {
            logger.debug("received streamId: {}, msg :{}", streamId, msg);
            handler.accept(msg);
        } else {
            // Nobody will consume this buffer, so release it here to avoid leaking it.
            if (msg instanceof ReadData) {
                ((ReadData)msg).getFlinkBuffer().release();
            }
            // A stream-end for an already-removed stream is not worth a warning.
            if (!(msg instanceof BufferStreamEnd)) {
                logger.warn("Unexpected streamId received: {}, msg: {}", streamId, msg);
            }
        }
    }

    /**
     * Handles a message that arrived with a response callback. The callback is not used: the
     * message is processed exactly like {@link #receive(TransportClient, RequestMessage)}.
     */
    @Override
    public void receive(TransportClient client, RequestMessage msg, RpcResponseCallback callback) {
        this.receive(client, msg);
    }

    /**
     * Routes {@code msg} by its type. Stream-level messages go to the consumer registered for
     * their stream id. RPC requests are decoded and re-dispatched through this method. One-way
     * messages are ignored, and any other type is logged as an error.
     */
    @Override
    public void receive(TransportClient client, RequestMessage msg) {
        switch (msg.type()) {
            case READ_DATA: {
                ReadData readData = (ReadData)msg;
                this.processMessageInternal(readData.getStreamId(), readData);
                break;
            }
            case SUBPARTITION_READ_DATA: {
                SubPartitionReadData subPartitionReadData = (SubPartitionReadData)msg;
                this.processMessageInternal(subPartitionReadData.getStreamId(), subPartitionReadData);
                break;
            }
            case BACKLOG_ANNOUNCEMENT: {
                BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement)msg;
                this.processMessageInternal(backlogAnnouncement.getStreamId(), backlogAnnouncement);
                break;
            }
            case TRANSPORTABLE_ERROR: {
                TransportableError transportableError = (TransportableError)msg;
                // Pass the address object itself so a null address is logged as "null" instead of
                // throwing an NPE from toString().
                logger.warn("Received TransportableError from worker {} with content {}", client.getSocketAddress(), transportableError.getErrorMessage());
                this.processMessageInternal(transportableError.getStreamId(), transportableError);
                break;
            }
            case BUFFER_STREAM_END: {
                BufferStreamEnd streamEnd = (BufferStreamEnd)msg;
                this.processMessageInternal(streamEnd.getStreamId(), streamEnd);
                break;
            }
            case RPC_REQUEST: {
                this.processRpcRequest(client, msg);
                break;
            }
            case ONE_WAY_MESSAGE: {
                // Ignored on the read path.
                break;
            }
            default: {
                logger.error("Unexpected msg type {} content {}", msg.type(), msg);
            }
        }
    }

    /**
     * Decodes the body of an RPC_REQUEST into a {@link TransportMessage} and re-dispatches the
     * wrapped stream-level message through {@link #receive(TransportClient, RequestMessage)}.
     * Decode failures and unrecognized message types are logged, and the message is dropped.
     */
    private void processRpcRequest(TransportClient client, RequestMessage msg) {
        try {
            TransportMessage transportMessage = TransportMessage.fromByteBuffer(msg.body().nioByteBuffer());
            int messageType = transportMessage.getMessageTypeValue();
            switch (messageType) {
                case BACKLOG_ANNOUNCEMENT_TYPE:
                    this.receive(client, BacklogAnnouncement.fromProto((PbBacklogAnnouncement)transportMessage.getParsedPayload()));
                    break;
                case BUFFER_STREAM_END_TYPE:
                    this.receive(client, BufferStreamEnd.fromProto((PbBufferStreamEnd)transportMessage.getParsedPayload()));
                    break;
                case TRANSPORTABLE_ERROR_TYPE:
                    this.receive(client, TransportableError.fromProto((PbTransportableError)transportMessage.getParsedPayload()));
                    break;
                default:
                    // Previously these were dropped silently; surface them so protocol drift is visible.
                    logger.warn("Unexpected RpcRequest message type {} content {}", messageType, msg);
            }
        }
        catch (IOException e) {
            logger.warn("Failed to process RpcRequest message {}. ", msg, e);
        }
    }

    /** Always reports this handler as registered. */
    @Override
    public boolean checkRegistered() {
        return true;
    }

    /**
     * Notifies every stream registered against {@code client} (compared by identity) that the
     * connection is lost. Each one receives a synthetic {@link TransportableError} carrying a
     * UTF-8 description.
     */
    @Override
    public void channelInactive(TransportClient client) {
        this.streamClients.forEach((streamId, savedClient) -> {
            if (savedClient == client) {
                String message = "Client " + client.getSocketAddress() + " is lost, notify related stream " + streamId;
                logger.warn(message);
                this.processMessageInternal(streamId, new TransportableError(streamId, message.getBytes(StandardCharsets.UTF_8)));
            }
        });
    }

    /** Logs the exception together with the client's socket address; no other action is taken. */
    @Override
    public void exceptionCaught(Throwable cause, TransportClient client) {
        logger.warn("exception caught {}", client.getSocketAddress(), cause);
    }

    /**
     * Drops all stream registrations and closes every client that had a stream registered.
     *
     * <p>Each distinct client is closed once, even if several streams share it. Clients are
     * deduplicated by identity, matching the {@code ==} check in {@link #channelInactive}.
     * Both registries are cleared before any client is closed. If closing a client causes
     * {@link #channelInactive} to run, it therefore finds no stale, handler-less streams to notify.
     */
    public void close() {
        Set<TransportClient> clients = Collections.newSetFromMap(new IdentityHashMap<>());
        clients.addAll(this.streamClients.values());
        this.streamHandlers.clear();
        this.streamClients.clear();
        for (TransportClient client : clients) {
            client.close();
        }
    }
}

