/*
 * Decompiled with CFR 0.152.
 */
package kd.ai.mcp.spec;

import com.fasterxml.jackson.core.type.TypeReference;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import kd.ai.mcp.server.McpAsyncServerExchange;
import kd.ai.mcp.spec.McpError;
import kd.ai.mcp.spec.McpSchema;
import kd.ai.mcp.spec.McpServerTransport;
import kd.ai.mcp.spec.McpSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.EmitterProcessor;
import reactor.core.publisher.FluxProcessor;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;

/**
 * Server side of a single MCP session, bound to one {@link McpServerTransport}.
 *
 * <p>The session
 * <ul>
 *   <li>sends JSON-RPC requests to the client and correlates the responses by request id,</li>
 *   <li>dispatches incoming requests and notifications to the registered handlers, and</li>
 *   <li>tracks the {@code initialize} / {@code notifications/initialized} handshake; once the
 *       handshake completes, a {@link McpAsyncServerExchange} is handed to every handler
 *       invocation.</li>
 * </ul>
 */
public class McpServerSession implements McpSession {
    private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class);

    /** JSON-RPC protocol version stamped on every outgoing message. */
    private static final String JSONRPC_VERSION = "2.0";
    /** JSON-RPC 2.0 "Method not found" error code. */
    private static final int METHOD_NOT_FOUND = -32601;
    /** JSON-RPC 2.0 "Internal error" error code. */
    private static final int INTERNAL_ERROR = -32603;
    /** Timeout applied to outbound requests when the caller does not supply one. */
    private static final Duration DEFAULT_REQUEST_TIMEOUT = Duration.ofSeconds(10L);

    private static final int STATE_UNINITIALIZED = 0;
    private static final int STATE_INITIALIZING = 1;
    private static final int STATE_INITIALIZED = 2;

    /** Sinks of outbound requests still waiting for a response, keyed by request id. */
    private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
    private final String id;
    private final Duration requestTimeout;
    private final AtomicLong requestCounter = new AtomicLong(0L);
    private final InitRequestHandler initRequestHandler;
    private final InitNotificationHandler initNotificationHandler;
    private final Map<String, RequestHandler<?>> requestHandlers;
    private final Map<String, NotificationHandler> notificationHandlers;
    private final McpServerTransport transport;
    /** Receives the exchange once the client sends {@code notifications/initialized}. */
    private final FluxProcessor<McpAsyncServerExchange, McpAsyncServerExchange> exchangeProcessor = EmitterProcessor.create(false);
    private final FluxSink<McpAsyncServerExchange> exchangeSink;
    /**
     * Source of the exchange for handler invocations.
     *
     * <p>An EmitterProcessor does not replay elements to late subscribers, and the exchange is
     * emitted only once. Subscribing to the processor directly would therefore give the exchange
     * to the first caller only and leave every later request waiting forever. {@code cache()}
     * replays the single exchange to every subscriber.
     */
    private final Mono<McpAsyncServerExchange> exchangeMono = Mono.from(this.exchangeProcessor).cache();
    private final AtomicReference<McpSchema.ClientCapabilities> clientCapabilities = new AtomicReference<>();
    private final AtomicReference<McpSchema.Implementation> clientInfo = new AtomicReference<>();
    /** Handshake progress; only written in this class (one of the STATE_* constants). */
    private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED);

    /**
     * Creates a session whose outbound requests time out after 10 seconds.
     *
     * @param id                      session identifier, also used as the request-id prefix
     * @param transport               transport used to send messages to the client
     * @param initHandler             handler for the {@code initialize} request
     * @param initNotificationHandler handler invoked on {@code notifications/initialized}
     * @param requestHandlers         request handlers keyed by JSON-RPC method name
     * @param notificationHandlers    notification handlers keyed by JSON-RPC method name
     */
    public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler,
            InitNotificationHandler initNotificationHandler, Map<String, RequestHandler<?>> requestHandlers,
            Map<String, NotificationHandler> notificationHandlers) {
        this(id, transport, initHandler, initNotificationHandler, requestHandlers, notificationHandlers,
                DEFAULT_REQUEST_TIMEOUT);
    }

    /**
     * Creates a session with a custom timeout for outbound requests.
     *
     * @param id                      session identifier, also used as the request-id prefix
     * @param transport               transport used to send messages to the client
     * @param initHandler             handler for the {@code initialize} request
     * @param initNotificationHandler handler invoked on {@code notifications/initialized}
     * @param requestHandlers         request handlers keyed by JSON-RPC method name
     * @param notificationHandlers    notification handlers keyed by JSON-RPC method name
     * @param requestTimeout          how long {@link #sendRequest} waits for a response
     * @throws NullPointerException if {@code requestTimeout} is null
     */
    public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler,
            InitNotificationHandler initNotificationHandler, Map<String, RequestHandler<?>> requestHandlers,
            Map<String, NotificationHandler> notificationHandlers, Duration requestTimeout) {
        this.id = id;
        this.transport = transport;
        this.initRequestHandler = initHandler;
        this.initNotificationHandler = initNotificationHandler;
        this.requestHandlers = requestHandlers;
        this.notificationHandlers = notificationHandlers;
        this.requestTimeout = Objects.requireNonNull(requestTimeout, "requestTimeout");
        this.exchangeSink = this.exchangeProcessor.sink();
    }

    /** Returns the session identifier. */
    public String getId() {
        return this.id;
    }

    /**
     * Records the client's capabilities and implementation info received during {@code initialize}.
     *
     * @param clientCapabilities capabilities announced by the client
     * @param clientInfo         client implementation info
     */
    public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) {
        this.clientCapabilities.lazySet(clientCapabilities);
        this.clientInfo.lazySet(clientInfo);
    }

    /** Builds a session-unique request id of the form {@code <sessionId>-<counter>}. */
    private String generateRequestId() {
        return this.id + "-" + this.requestCounter.getAndIncrement();
    }

    /**
     * Sends a JSON-RPC request to the client and decodes the matching response.
     *
     * <p>The returned Mono fails with {@link McpError} when the response carries an error. It
     * completes empty when {@code typeRef} denotes {@code Void}. It fails with a timeout error if
     * no response arrives within the configured request timeout.
     */
    @Override
    public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
        String requestId = this.generateRequestId();
        return Mono.<McpSchema.JSONRPCResponse>create(sink -> {
            // Register before sending so a fast response cannot arrive ahead of the registration.
            this.pendingResponses.put(requestId, sink);
            // However the request ends (response, transport error, timeout or cancellation), drop the
            // entry so abandoned requests do not accumulate. The two-arg remove only clears this sink.
            sink.onDispose(() -> this.pendingResponses.remove(requestId, sink));
            McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(JSONRPC_VERSION, method, requestId, requestParams);
            this.transport.sendMessage(jsonrpcRequest).subscribe(unused -> { }, error -> {
                this.pendingResponses.remove(requestId, sink);
                sink.error(error);
            });
        }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> {
            if (jsonRpcResponse.error() != null) {
                sink.error(new McpError(jsonRpcResponse.error()));
            } else if (typeRef.getType().equals(Void.class)) {
                sink.complete();
            } else {
                sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef));
            }
        });
    }

    /** Sends a JSON-RPC notification (no response expected) to the client. */
    @Override
    public Mono<Void> sendNotification(String method, Map<String, Object> params) {
        McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(JSONRPC_VERSION, method, params);
        return this.transport.sendMessage(jsonrpcNotification);
    }

    /**
     * Dispatches one incoming JSON-RPC message.
     *
     * <ul>
     *   <li>A response completes the matching pending {@link #sendRequest} call.</li>
     *   <li>A request is routed to its handler, and the resulting response is sent back.</li>
     *   <li>A notification is routed to its handler.</li>
     *   <li>Unknown message types are logged and ignored.</li>
     * </ul>
     */
    public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
        return Mono.defer(() -> {
            if (message instanceof McpSchema.JSONRPCResponse) {
                McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) message;
                logger.debug("Received Response: {}", response);
                Object responseId = response.id();
                // JSON-RPC allows a null id on error responses; ConcurrentHashMap rejects null keys.
                MonoSink<McpSchema.JSONRPCResponse> sink = responseId == null ? null : this.pendingResponses.remove(responseId);
                if (sink == null) {
                    logger.warn("Unexpected response for unknown id {}", responseId);
                } else {
                    sink.success(response);
                }
                return Mono.empty();
            }
            if (message instanceof McpSchema.JSONRPCRequest) {
                McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) message;
                logger.debug("Received request: {}", request);
                return this.handleIncomingRequest(request).onErrorResume(error -> {
                    McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse(JSONRPC_VERSION, request.id(), null,
                            new McpSchema.JSONRPCResponse.JSONRPCError(INTERNAL_ERROR, error.getMessage(), null));
                    // The error reply is sent here; completing empty skips the flatMap below.
                    return this.transport.sendMessage(errorResponse).then(Mono.empty());
                }).flatMap(response -> this.transport.sendMessage(response));
            }
            if (message instanceof McpSchema.JSONRPCNotification) {
                McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) message;
                logger.debug("Received notification: {}", notification);
                return this.handleIncomingNotification(notification)
                        .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage(), error));
            }
            logger.warn("Received unknown message type: {}", message);
            return Mono.empty();
        });
    }

    /**
     * Builds the response for an incoming request.
     *
     * <p>{@code initialize} is handled by the init handler and moves the session to
     * STATE_INITIALIZING. Other methods go to the registered handler once the exchange is
     * available. Unknown methods yield a "Method not found" error response. Handler failures
     * yield an "Internal error" response.
     */
    private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request) {
        return Mono.defer(() -> {
            Mono<?> resultMono;
            if ("initialize".equals(request.method())) {
                McpSchema.InitializeRequest initializeRequest = this.transport.unmarshalFrom(request.params(),
                        new TypeReference<McpSchema.InitializeRequest>() { });
                this.state.lazySet(STATE_INITIALIZING);
                this.init(initializeRequest.capabilities(), initializeRequest.clientInfo());
                resultMono = this.initRequestHandler.handle(initializeRequest);
            } else {
                RequestHandler<?> handler = this.requestHandlers.get(request.method());
                if (handler == null) {
                    MethodNotFoundError notFound = getMethodNotFoundError(request.method());
                    return Mono.just(new McpSchema.JSONRPCResponse(JSONRPC_VERSION, request.id(), null,
                            new McpSchema.JSONRPCResponse.JSONRPCError(METHOD_NOT_FOUND, notFound.getMessage(), notFound.getData())));
                }
                resultMono = this.exchangeMono.flatMap(exchange -> handler.handle(exchange, request.params()));
            }
            return resultMono
                    .map(result -> new McpSchema.JSONRPCResponse(JSONRPC_VERSION, request.id(), result, null))
                    .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(JSONRPC_VERSION, request.id(), null,
                            new McpSchema.JSONRPCResponse.JSONRPCError(INTERNAL_ERROR, error.getMessage(), null))));
        });
    }

    /**
     * Handles an incoming notification.
     *
     * <p>{@code notifications/initialized} completes the handshake and publishes the exchange.
     * Other notifications go to their registered handler, or are logged and dropped when no
     * handler exists.
     */
    private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification) {
        return Mono.defer(() -> {
            if ("notifications/initialized".equals(notification.method())) {
                this.state.lazySet(STATE_INITIALIZED);
                this.exchangeSink.next(new McpAsyncServerExchange(this, this.clientCapabilities.get(), this.clientInfo.get()));
                return this.initNotificationHandler.handle();
            }
            NotificationHandler handler = this.notificationHandlers.get(notification.method());
            if (handler == null) {
                logger.error("No handler registered for notification method: {}", notification.method());
                return Mono.empty();
            }
            return this.exchangeMono.flatMap(exchange -> handler.handle(exchange, notification.params()));
        });
    }

    /**
     * Describes why {@code method} has no handler. {@code roots/list} gets a capability-specific
     * message; every other method gets a generic "Method not found" message.
     */
    static MethodNotFoundError getMethodNotFoundError(String method) {
        if ("roots/list".equals(method)) {
            return new MethodNotFoundError(method, "Roots not supported",
                    Collections.singletonMap("reason", "Client does not have roots capability"));
        }
        return new MethodNotFoundError(method, "Method not found: " + method, null);
    }

    /** Closes the underlying transport gracefully. */
    @Override
    public Mono<Void> closeGracefully() {
        return this.transport.closeGracefully();
    }

    /** Closes the underlying transport immediately. */
    @Override
    public void close() {
        this.transport.close();
    }

    /** Creates a session bound to a given transport. */
    @FunctionalInterface
    public interface Factory {
        McpServerSession create(McpServerTransport transport);
    }

    /** Handles one request method; the result becomes the JSON-RPC {@code result}. */
    public interface RequestHandler<T> {
        Mono<T> handle(McpAsyncServerExchange exchange, Object params);
    }

    /** Handles one notification method. */
    public interface NotificationHandler {
        Mono<Void> handle(McpAsyncServerExchange exchange, Object params);
    }

    /** Invoked when the client sends {@code notifications/initialized}. */
    public interface InitNotificationHandler {
        Mono<Void> handle();
    }

    /** Produces the result of the {@code initialize} request. */
    public interface InitRequestHandler {
        Mono<McpSchema.InitializeResult> handle(McpSchema.InitializeRequest request);
    }

    /** Message and optional data for a "Method not found" error response. */
    private static final class MethodNotFoundError {
        private final String method;
        private final String message;
        private final Object data;

        public MethodNotFoundError(String method, String message, Object data) {
            this.method = method;
            this.message = message;
            this.data = data;
        }

        public String getMethod() {
            return this.method;
        }

        public String getMessage() {
            return this.message;
        }

        public Object getData() {
            return this.data;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            MethodNotFoundError that = (MethodNotFoundError) o;
            return Objects.equals(this.method, that.method) && Objects.equals(this.message, that.message)
                    && Objects.equals(this.data, that.data);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.method, this.message, this.data);
        }

        @Override
        public String toString() {
            return "MethodNotFoundError{method='" + this.method + '\'' + ", message='" + this.message + '\'' + ", data=" + this.data + '}';
        }
    }
}

