/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shenyu.plugin.mcp.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;

public class ShenyuSseServerTransportProvider
implements McpServerTransportProvider {
    private static final Logger LOGGER = LoggerFactory.getLogger(ShenyuSseServerTransportProvider.class);
    private static final String MESSAGE_EVENT_TYPE = "message";
    private static final String ENDPOINT_EVENT_TYPE = "endpoint";
    private static final String DEFAULT_SSE_ENDPOINT = "/sse";
    private static final String DEFAULT_BASE_URL = "";
    private final ObjectMapper objectMapper;
    private final String baseUrl;
    private final String messageEndpoint;
    private final String sseEndpoint;
    private McpServerSession.Factory sessionFactory;
    private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap();
    private volatile boolean isClosing;

    public ShenyuSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
    }

    public ShenyuSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
        this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
    }

    public ShenyuSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
        Assert.notNull((Object)objectMapper, (String)"ObjectMapper must not be null");
        Assert.notNull((Object)baseUrl, (String)"Message base path must not be null");
        Assert.notNull((Object)messageEndpoint, (String)"Message endpoint must not be null");
        Assert.notNull((Object)sseEndpoint, (String)"SSE endpoint must not be null");
        this.objectMapper = objectMapper;
        this.baseUrl = baseUrl;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
    }

    public static Builder builder() {
        return new Builder();
    }

    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            LOGGER.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        LOGGER.debug("Attempting to broadcast message to {} active sessions", (Object)this.sessions.size());
        return Flux.fromIterable(this.sessions.values()).flatMap(session -> session.sendNotification(method, params).doOnError(e -> LOGGER.error("Failed to send message to session {}: {}", (Object)session.getId(), (Object)e.getMessage())).onErrorComplete()).then();
    }

    public Mono<Void> closeGracefully() {
        return Flux.fromIterable(this.sessions.values()).doFirst(() -> LOGGER.debug("Initiating graceful shutdown with {} active sessions", (Object)this.sessions.size())).flatMap(McpServerSession::closeGracefully).then();
    }

    public Mono<ServerResponse> handleSseConnection(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).bodyValue((Object)"Server is shutting down");
        }
        if (Objects.isNull(this.sessionFactory)) {
            LOGGER.error("SessionFactory is null - MCP server not properly initialized");
            return ServerResponse.status((HttpStatusCode)HttpStatus.INTERNAL_SERVER_ERROR).bodyValue((Object)"MCP server not properly initialized");
        }
        return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body((Object)Flux.create(sink -> {
            try {
                WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport((FluxSink<ServerSentEvent<?>>)sink);
                McpServerSession session = this.sessionFactory.create((McpServerTransport)sessionTransport);
                String sessionId = session.getId();
                LOGGER.debug("Created new SSE connection for session: {}", (Object)sessionId);
                this.sessions.put(sessionId, session);
                LOGGER.debug("Sending initial endpoint event to session: {}", (Object)sessionId);
                String endpointUrl = this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId;
                LOGGER.debug("Endpoint URL: {}", (Object)endpointUrl);
                ServerSentEvent endpointEvent = ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data((Object)endpointUrl).build();
                sink.next((Object)endpointEvent);
                sink.onCancel(() -> {
                    LOGGER.debug("Session {} cancelled", (Object)sessionId);
                    this.sessions.remove(sessionId);
                });
            }
            catch (Exception e) {
                LOGGER.error("Error creating SSE session", (Throwable)e);
                sink.error((Throwable)e);
            }
        }), ServerSentEvent.class);
    }

    public Flux<ServerSentEvent<?>> createSseFlux(ServerRequest request) {
        if (this.isClosing) {
            return Flux.error((Throwable)new RuntimeException("Server is shutting down"));
        }
        if (Objects.isNull(this.sessionFactory)) {
            LOGGER.error("SessionFactory is null - MCP server not properly initialized");
            return Flux.error((Throwable)new RuntimeException("MCP server not properly initialized"));
        }
        return Flux.create(sink -> {
            try {
                WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport((FluxSink<ServerSentEvent<?>>)sink);
                McpServerSession session = this.sessionFactory.create((McpServerTransport)sessionTransport);
                String sessionId = session.getId();
                LOGGER.info("Created new SSE connection for session: {}", (Object)sessionId);
                this.sessions.put(sessionId, session);
                LOGGER.info("Sending initial endpoint event to session: {}", (Object)sessionId);
                String endpointUrl = this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId;
                LOGGER.info("Endpoint URL: {}", (Object)endpointUrl);
                ServerSentEvent endpointEvent = ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data((Object)endpointUrl).build();
                LOGGER.info("Built endpoint event - Type: {}, Data: {}", (Object)endpointEvent.event(), endpointEvent.data());
                sink.next((Object)endpointEvent);
                LOGGER.info("Successfully sent initial endpoint event for session: {}", (Object)sessionId);
                sink.onCancel(() -> {
                    LOGGER.info("Session {} cancelled by client", (Object)sessionId);
                    this.sessions.remove(sessionId);
                });
                sink.onDispose(() -> {
                    LOGGER.info("Session {} disposed", (Object)sessionId);
                    this.sessions.remove(sessionId);
                });
            }
            catch (Exception e) {
                LOGGER.error("Error creating SSE session", (Throwable)e);
                sink.error((Throwable)e);
            }
        }).doOnSubscribe(subscription -> LOGGER.info("SSE Flux subscribed")).doOnRequest(n -> LOGGER.debug("SSE Flux requested {} items", (Object)n));
    }

    public Mono<ServerResponse> handleMessage(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).bodyValue((Object)"Server is shutting down");
        }
        if (request.queryParam("sessionId").isEmpty()) {
            return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"Session ID missing in message endpoint"));
        }
        McpServerSession session = this.sessions.get(request.queryParam("sessionId").get());
        if (Objects.isNull(session)) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.NOT_FOUND).bodyValue((Object)new McpError((Object)("Session not found: " + (String)request.queryParam("sessionId").get())));
        }
        return request.bodyToMono(String.class).flatMap(body -> {
            try {
                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage((ObjectMapper)this.objectMapper, (String)body);
                return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> {
                    LOGGER.error("Error processing  message: {}", (Object)error.getMessage());
                    return ServerResponse.status((HttpStatusCode)HttpStatus.INTERNAL_SERVER_ERROR).bodyValue((Object)new McpError((Object)error.getMessage()));
                });
            }
            catch (IOException | IllegalArgumentException e) {
                LOGGER.error("Failed to deserialize message: {}", (Object)e.getMessage());
                return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"Invalid message format"));
            }
        });
    }

    public Mono<MessageHandlingResult> handleMessageEndpoint(ServerRequest request) {
        if (this.isClosing) {
            LOGGER.warn("Server is shutting down, rejecting message request");
            return Mono.just((Object)new MessageHandlingResult(503, "Server is shutting down"));
        }
        if (request.queryParam("sessionId").isEmpty()) {
            LOGGER.warn("Session ID missing in message endpoint");
            return Mono.just((Object)new MessageHandlingResult(400, "Session ID missing in message endpoint"));
        }
        String sessionId = (String)request.queryParam("sessionId").get();
        McpServerSession session = this.sessions.get(sessionId);
        if (Objects.isNull(session)) {
            LOGGER.warn("Session not found: {}", (Object)sessionId);
            return Mono.just((Object)new MessageHandlingResult(404, "Session not found: " + sessionId));
        }
        LOGGER.info("Processing message for session: {}", (Object)sessionId);
        return request.bodyToMono(String.class).flatMap(body -> {
            try {
                LOGGER.debug("Received message body: {}", body);
                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage((ObjectMapper)this.objectMapper, (String)body);
                LOGGER.info("Deserialized JSON-RPC message for session: {}", (Object)sessionId);
                return session.handle(message).doOnSuccess(result -> LOGGER.info("Successfully processed message for session: {}", (Object)sessionId)).map(response -> new MessageHandlingResult(200, "Message processed successfully")).onErrorResume(error -> {
                    LOGGER.error("Error processing message for session {}: {}", (Object)sessionId, (Object)error.getMessage());
                    return Mono.just((Object)new MessageHandlingResult(500, "Error processing message: " + error.getMessage()));
                });
            }
            catch (IOException | IllegalArgumentException e) {
                LOGGER.error("Failed to deserialize message for session {}: {}", (Object)sessionId, (Object)e.getMessage());
                return Mono.just((Object)new MessageHandlingResult(400, "Invalid message format: " + e.getMessage()));
            }
        }).onErrorResume(error -> {
            LOGGER.error("Unexpected error handling message for session {}: {}", (Object)sessionId, (Object)error.getMessage());
            return Mono.just((Object)new MessageHandlingResult(500, "Unexpected error: " + error.getMessage()));
        });
    }

    public static class Builder {
        private ObjectMapper objectMapper;
        private String baseUrl = "";
        private String messageEndpoint;
        private String sseEndpoint = "/sse";

        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull((Object)objectMapper, (String)"ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        public Builder basePath(String baseUrl) {
            Assert.notNull((Object)baseUrl, (String)"basePath must not be null");
            this.baseUrl = baseUrl;
            return this;
        }

        public Builder messageEndpoint(String messageEndpoint) {
            Assert.notNull((Object)messageEndpoint, (String)"Message endpoint must not be null");
            this.messageEndpoint = messageEndpoint;
            return this;
        }

        public Builder sseEndpoint(String sseEndpoint) {
            Assert.notNull((Object)sseEndpoint, (String)"SSE endpoint must not be null");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        public ShenyuSseServerTransportProvider build() {
            Assert.notNull((Object)this.objectMapper, (String)"ObjectMapper must be set");
            Assert.notNull((Object)this.messageEndpoint, (String)"Message endpoint must be set");
            return new ShenyuSseServerTransportProvider(this.objectMapper, this.baseUrl, this.messageEndpoint, this.sseEndpoint);
        }
    }

    public static class MessageHandlingResult {
        private final int statusCode;
        private final String responseBody;

        public MessageHandlingResult(int statusCode, String responseBody) {
            this.statusCode = statusCode;
            this.responseBody = responseBody;
        }

        public int getStatusCode() {
            return this.statusCode;
        }

        public String getResponseBody() {
            return this.responseBody;
        }
    }

    private class WebFluxMcpSessionTransport
    implements McpServerTransport {
        private final FluxSink<ServerSentEvent<?>> sink;

        WebFluxMcpSessionTransport(FluxSink<ServerSentEvent<?>> sink) {
            this.sink = sink;
        }

        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromSupplier(() -> {
                try {
                    return ShenyuSseServerTransportProvider.this.objectMapper.writeValueAsString((Object)message);
                }
                catch (IOException e) {
                    throw Exceptions.propagate((Throwable)e);
                }
            }).doOnNext(jsonText -> {
                ServerSentEvent event = ServerSentEvent.builder().event(ShenyuSseServerTransportProvider.MESSAGE_EVENT_TYPE).data(jsonText).build();
                this.sink.next((Object)event);
            }).doOnError(e -> {
                Throwable exception = Exceptions.unwrap((Throwable)e);
                this.sink.error(exception);
            }).then();
        }

        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return (T)ShenyuSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
        }

        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> this.sink.complete());
        }

        public void close() {
            this.sink.complete();
        }
    }
}

