/*
 * Decompiled with CFR 0.152.
 */
package kd.ai.mcp.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
import javax.servlet.AsyncContext;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import kd.ai.mcp.client.transport.HttpUtils;
import kd.ai.mcp.spec.McpError;
import kd.ai.mcp.spec.McpSchema;
import kd.ai.mcp.spec.McpServerSession;
import kd.ai.mcp.spec.McpServerTransport;
import kd.ai.mcp.spec.McpServerTransportProvider;
import kd.ai.mcp.spec.ServerMcpTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

/**
 * MCP server transport provider that speaks Server-Sent Events (SSE) over the Servlet API.
 *
 * <p>Protocol as implemented here:
 * <ul>
 *   <li>A GET whose path info equals {@code sseEndpoint} opens an SSE stream. The provider creates
 *       a session under a random UUID and immediately sends an {@value #ENDPOINT_EVENT_TYPE} event
 *       whose data is {@code messageEndpoint + "?sessionId=<id>"}.</li>
 *   <li>A POST whose path info equals {@code messageEndpoint}, carrying a {@code sessionId} query
 *       parameter, delivers one JSON-RPC message to that session.</li>
 *   <li>Messages sent through a session's transport are written to its SSE stream as
 *       {@value #MESSAGE_EVENT_TYPE} events.</li>
 * </ul>
 *
 * <p>Endpoints are compared against {@link HttpServletRequest#getPathInfo()}, i.e. relative to the
 * servlet mapping. Active sessions live in a {@link ConcurrentHashMap}; writes to one SSE stream are
 * serialized so that concurrent senders cannot interleave events.
 */
@WebServlet(asyncSupported=true)
public class HttpServletSseServerTransportProvider
extends McpServerTransportProvider {
    public static final String UTF_8 = "UTF-8";
    public static final String APPLICATION_JSON = "application/json";
    public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class);
    /** Line terminators recognised by the SSE wire format: CRLF, LF, or a lone CR. */
    private static final Pattern SSE_LINE_BREAK = Pattern.compile("\r\n|\r|\n");
    private final ObjectMapper objectMapper;
    /** Path info that POSTed JSON-RPC messages must target. */
    private final String messageEndpoint;
    /** Path info that opens a new SSE stream via GET. */
    private final String sseEndpoint;
    /** Active sessions keyed by the UUID handed to the client in the endpoint event. */
    private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<String, McpServerSession>();
    /** Set once {@link #closeGracefully()} starts; new GET/POST requests are then rejected with 503. */
    private final AtomicBoolean isClosing = new AtomicBoolean(false);
    private McpServerSession.Factory sessionFactory;
    /** Plain {@link HttpServlet} that forwards GET, POST and destroy to this provider. */
    private final HttpServlet servletDelegate = new HttpServlet(){

        protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
            HttpServletSseServerTransportProvider.this.doGet(request, response);
        }

        protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
            HttpServletSseServerTransportProvider.this.doPost(request, response);
        }

        public void destroy() {
            HttpServletSseServerTransportProvider.this.destroy();
        }
    };

    /**
     * Creates a provider with explicit endpoints.
     *
     * @param objectMapper    mapper used to (de)serialize JSON-RPC messages and error bodies
     * @param messageEndpoint path info that POSTed messages must target
     * @param sseEndpoint     path info that opens an SSE stream
     */
    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
        this.objectMapper = objectMapper;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
    }

    /**
     * Creates a provider whose SSE endpoint is {@value #DEFAULT_SSE_ENDPOINT}.
     *
     * @param objectMapper    mapper used to (de)serialize JSON-RPC messages and error bodies
     * @param messageEndpoint path info that POSTed messages must target
     */
    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
    }

    /**
     * Sets the factory used to create a session for every new SSE connection. It must be set
     * before the first GET is served.
     */
    @Override
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Sends a notification to every active session.
     *
     * <p>A failure for one session is logged and does not stop delivery to the remaining sessions.
     *
     * @param method notification method name
     * @param params notification parameters
     * @return a Mono that performs the broadcast (blocking per session) when subscribed, or an
     *         empty Mono if there were no sessions at call time
     */
    @Override
    public Mono<Void> notifyClients(String method, Map<String, Object> params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        return Mono.fromRunnable(() -> this.sessions.values().forEach(session -> {
            try {
                session.sendNotification(method, params).block();
            }
            catch (Exception e) {
                // Isolate failures so one broken stream does not starve the other clients.
                logger.warn("Failed to send notification to session {}: {}", session.getId(), e.getMessage());
            }
        }));
    }

    /**
     * Opens an SSE stream, registers a new session and sends it the {@value #ENDPOINT_EVENT_TYPE}
     * event.
     *
     * <p>Responds 404 for any path other than the SSE endpoint and 503 while shutting down. If the
     * endpoint event cannot be delivered, the session is unregistered and its async context is
     * completed.
     *
     * @throws IllegalStateException if no session factory has been set
     */
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        String pathInfo = request.getPathInfo();
        if (!this.sseEndpoint.equals(pathInfo)) {
            response.sendError(404);
            return;
        }
        if (this.isClosing.get()) {
            response.sendError(503, "Server is shutting down");
            return;
        }
        McpServerSession.Factory factory = this.sessionFactory;
        if (factory == null) {
            // Fail before starting async processing, so no dangling AsyncContext is left behind.
            throw new IllegalStateException("Session factory has not been set; call setSessionFactory before serving requests");
        }
        HttpUtils.setResponseHeaders(response);
        String sessionId = UUID.randomUUID().toString();
        AsyncContext asyncContext = request.startAsync();
        // 0 = never time out: the SSE stream stays open until the session is closed.
        asyncContext.setTimeout(0L);
        PrintWriter writer = response.getWriter();
        HttpServletMcpSessionTransport sessionTransport = new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
        McpServerSession session = factory.create(sessionTransport);
        this.sessions.put(sessionId, session);
        try {
            this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.messageEndpoint + "?sessionId=" + sessionId);
        }
        catch (IOException e) {
            // The client went away before it learned its endpoint; do not keep a zombie session.
            logger.warn("Failed to send endpoint event to session {}: {}", sessionId, e.getMessage());
            sessionTransport.close();
        }
    }

    /**
     * Delivers one POSTed JSON-RPC message to the session named by the {@code sessionId} parameter.
     *
     * <p>Responds 503 while shutting down, 404 for a wrong path, 400 (JSON error) if
     * {@code sessionId} is missing, and 404 (JSON error) for an unknown session. Returns 200 once
     * the session has handled the message, or 500 (JSON error) if reading, parsing or handling
     * fails.
     */
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        if (this.isClosing.get()) {
            response.sendError(503, "Server is shutting down");
            return;
        }
        String pathInfo = request.getPathInfo();
        if (!this.messageEndpoint.equals(pathInfo)) {
            response.sendError(404);
            return;
        }
        String sessionId = request.getParameter("sessionId");
        if (sessionId == null) {
            this.sendJsonError(response, 400, "Session ID missing in message endpoint");
            return;
        }
        McpServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            this.sendJsonError(response, 404, "Session not found: " + sessionId);
            return;
        }
        try {
            String body = readBody(request.getReader());
            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, body);
            session.handle(message).block();
            response.setStatus(200);
        }
        catch (Exception e) {
            // Pass the exception itself so the stack trace is not lost.
            logger.error("Error processing message: {}", e.getMessage(), e);
            String reason = e.getMessage() != null ? e.getMessage() : e.getClass().getName();
            try {
                this.sendJsonError(response, 500, reason);
            }
            catch (IOException ex) {
                logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
                // sendError on a committed response would throw IllegalStateException.
                if (!response.isCommitted()) {
                    response.sendError(500, "Error processing message");
                }
            }
        }
    }

    /**
     * Writes an {@link McpError} as a JSON body with the given HTTP status.
     *
     * @throws IOException if serialization or writing fails
     */
    private void sendJsonError(HttpServletResponse response, int status, String message) throws IOException {
        response.setContentType(APPLICATION_JSON);
        response.setCharacterEncoding(UTF_8);
        response.setStatus(status);
        // Keep the explicit Object cast so the same McpError constructor overload as before is chosen.
        String jsonError = this.objectMapper.writeValueAsString(new McpError((Object)message));
        PrintWriter writer = response.getWriter();
        writer.write(jsonError);
        writer.flush();
    }

    /**
     * Reads everything remaining in {@code reader} verbatim, line terminators included.
     */
    private static String readBody(BufferedReader reader) throws IOException {
        StringBuilder body = new StringBuilder();
        char[] buffer = new char[4096];
        int count;
        while ((count = reader.read(buffer)) != -1) {
            body.append(buffer, 0, count);
        }
        return body.toString();
    }

    /**
     * Marks the provider as closing and closes every active session.
     *
     * <p>Each session is removed from the registry even if closing it fails. A failure is logged
     * and does not prevent the remaining sessions from being closed.
     *
     * @return a Mono that performs the shutdown (blocking per session) when subscribed
     */
    @Override
    public Mono<Void> closeGracefully() {
        this.isClosing.set(true);
        logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
        // ConcurrentHashMap iterators tolerate removal during traversal.
        return Mono.fromRunnable(() -> this.sessions.values().forEach(session -> {
            try {
                session.closeGracefully().block();
            }
            catch (Exception e) {
                logger.warn("Failed to close session {}: {}", session.getId(), e.getMessage());
            }
            finally {
                this.sessions.remove(session.getId());
            }
        }));
    }

    /**
     * Writes one SSE event and flushes it.
     *
     * <p>Each line of {@code data} is emitted as its own {@code data:} field, as the SSE format
     * requires, so payloads containing line breaks keep their framing. The whole event is written
     * while holding the writer's monitor, so events from concurrent senders on the same stream
     * cannot interleave.
     *
     * @throws IOException if the writer reports an error after flushing (treated as a disconnect)
     */
    private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException {
        StringBuilder event = new StringBuilder();
        event.append("event: ").append(eventType).append('\n');
        for (String line : SSE_LINE_BREAK.split(String.valueOf(data), -1)) {
            event.append("data: ").append(line).append('\n');
        }
        event.append('\n');
        synchronized (writer) {
            writer.write(event.toString());
            writer.flush();
            if (writer.checkError()) {
                throw new IOException("Client disconnected");
            }
        }
    }

    /** Servlet shutdown hook: closes all sessions synchronously. */
    public void destroy() {
        this.closeGracefully().block();
    }

    /**
     * Not supported: transports are created per SSE connection in {@code doGet}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public ServerMcpTransport createTransport() {
        throw new UnsupportedOperationException("This transport provider does not support direct transport creation");
    }

    /**
     * Per-connection transport that writes outgoing messages to one SSE stream and owns that
     * stream's {@link AsyncContext}.
     */
    private class HttpServletMcpSessionTransport
    implements McpServerTransport {
        private final String sessionId;
        private final AsyncContext asyncContext;
        private final PrintWriter writer;

        HttpServletMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) {
            this.sessionId = sessionId;
            this.asyncContext = asyncContext;
            this.writer = writer;
            logger.debug("Session transport {} initialized with SSE writer", sessionId);
        }

        /**
         * Serializes {@code message} and writes it as a {@value #MESSAGE_EVENT_TYPE} event.
         *
         * <p>On failure the error is logged, and the session is unregistered and its stream
         * completed. The returned Mono does not signal an error.
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromRunnable(() -> {
                try {
                    String jsonText = HttpServletSseServerTransportProvider.this.objectMapper.writeValueAsString(message);
                    HttpServletSseServerTransportProvider.this.sendEvent(this.writer, MESSAGE_EVENT_TYPE, jsonText);
                    logger.debug("Message sent to session {}", this.sessionId);
                }
                catch (Exception e) {
                    logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
                    // close() guards complete(), which throws if the context has already completed.
                    this.close();
                }
            });
        }

        /** Converts {@code data} (e.g. a deserialized JSON tree or map) to {@code T} via the shared mapper. */
        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return HttpServletSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
        }

        /** Returns a Mono that, when subscribed, performs {@link #close()}. */
        @Override
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                logger.debug("Closing session transport: {}", this.sessionId);
                this.close();
            });
        }

        /**
         * Unregisters the session and completes its async context. A failure to complete is
         * logged, not thrown.
         */
        @Override
        public void close() {
            HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
            try {
                this.asyncContext.complete();
                logger.debug("Successfully completed async context for session {}", this.sessionId);
            }
            catch (Exception e) {
                logger.warn("Failed to complete async context for session {}: {}", this.sessionId, e.getMessage());
            }
        }
    }
}

