/*
 * Decompiled with CFR 0.152.
 */
package kd.ai.mcp.client;

import com.fasterxml.jackson.core.type.TypeReference;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import kd.ai.mcp.client.McpClientFeatures;
import kd.ai.mcp.spec.ClientMcpTransport;
import kd.ai.mcp.spec.McpClientSession;
import kd.ai.mcp.spec.McpError;
import kd.ai.mcp.spec.McpSchema;
import kd.ai.mcp.spec.McpTransport;
import kd.ai.mcp.util.Assert;
import kd.ai.mcp.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.EmitterProcessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxProcessor;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;

/**
 * Asynchronous, Project Reactor based client for the Model Context Protocol (MCP).
 *
 * <p>{@link #initialize()} must complete successfully before server operations are issued.
 * Every operation that goes through {@code withInitializationCheck} waits for the first
 * successful initialize result. If that result is not produced within the configured
 * initialization timeout, the operation fails with an {@link McpError}.
 */
public class McpAsyncClient {
    private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class);

    // Type tokens are created once and reused instead of allocating an anonymous subclass per call.
    private static final TypeReference<Void> VOID_TYPE_REFERENCE = new TypeReference<Void>() {};
    private static final TypeReference<Object> OBJECT_TYPE_REF = new TypeReference<Object>() {};
    private static final TypeReference<String> STRING_TYPE_REF = new TypeReference<String>() {};
    private static final TypeReference<McpSchema.InitializeResult> INITIALIZE_RESULT_TYPE_REF = new TypeReference<McpSchema.InitializeResult>() {};
    private static final TypeReference<McpSchema.CallToolResult> CALL_TOOL_RESULT_TYPE_REF = new TypeReference<McpSchema.CallToolResult>() {};
    private static final TypeReference<McpSchema.ListToolsResult> LIST_TOOLS_RESULT_TYPE_REF = new TypeReference<McpSchema.ListToolsResult>() {};
    private static final TypeReference<McpSchema.ListResourcesResult> LIST_RESOURCES_RESULT_TYPE_REF = new TypeReference<McpSchema.ListResourcesResult>() {};
    private static final TypeReference<McpSchema.ReadResourceResult> READ_RESOURCE_RESULT_TYPE_REF = new TypeReference<McpSchema.ReadResourceResult>() {};
    private static final TypeReference<McpSchema.ListResourceTemplatesResult> LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeReference<McpSchema.ListResourceTemplatesResult>() {};
    private static final TypeReference<McpSchema.ListPromptsResult> LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<McpSchema.ListPromptsResult>() {};
    private static final TypeReference<McpSchema.GetPromptResult> GET_PROMPT_RESULT_TYPE_REF = new TypeReference<McpSchema.GetPromptResult>() {};
    private static final TypeReference<McpSchema.PaginatedRequest> PAGINATED_REQUEST_TYPE_REF = new TypeReference<McpSchema.PaginatedRequest>() {};
    private static final TypeReference<McpSchema.CreateMessageRequest> CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<McpSchema.CreateMessageRequest>() {};
    private static final TypeReference<McpSchema.LoggingMessageNotification> LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<McpSchema.LoggingMessageNotification>() {};

    /** Receives the initialize result once {@link #initialize()} has validated it. */
    private final FluxProcessor<McpSchema.InitializeResult, McpSchema.InitializeResult> initializeProcessor;
    private final FluxSink<McpSchema.InitializeResult> initializeSink;
    /** First initialize result (or timeout error), replayed to every later subscriber. */
    private final Mono<McpSchema.InitializeResult> cachedInitResult;
    private final Duration initializationTimeout;
    private final McpClientSession mcpSession;
    private final McpSchema.ClientCapabilities clientCapabilities;
    private final McpSchema.Implementation clientInfo;
    /** Roots exposed to the server, keyed by root URI. */
    private final ConcurrentHashMap<String, McpSchema.Root> roots;
    private final McpTransport transport;
    private final AtomicBoolean initialized = new AtomicBoolean(false);
    /** Set from the server's initialize response; null until one is received. */
    private McpSchema.ServerCapabilities serverCapabilities;
    private McpSchema.Implementation serverInfo;
    private Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler;
    /** Supported protocol versions; the last entry is the one requested in initialize. */
    private List<String> protocolVersions = Collections.singletonList("2024-11-05");

    /**
     * Creates a client and wires its request and notification handlers into a new session.
     *
     * @param transport transport used by the session and for unmarshalling incoming params
     * @param requestTimeout timeout applied by the session to outgoing requests
     * @param initializationTimeout how long guarded operations wait for initialization
     * @param features client info, capabilities, roots, sampling handler and change consumers
     * @throws McpError if the capabilities include sampling but no sampling handler is supplied
     */
    McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, Duration initializationTimeout, McpClientFeatures.Async features) {
        Assert.notNull(transport, "Transport must not be null");
        Assert.notNull(requestTimeout, "Request timeout must not be null");
        Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
        Assert.notNull(features, "Features must not be null");
        this.clientInfo = features.clientInfo();
        this.clientCapabilities = features.clientCapabilities();
        this.transport = transport;
        this.roots = new ConcurrentHashMap<String, McpSchema.Root>(features.roots());
        this.initializationTimeout = initializationTimeout;
        // autoCancel=false: the processor is not shut down when its subscribers cancel.
        this.initializeProcessor = EmitterProcessor.create(false);
        this.initializeSink = this.initializeProcessor.sink();
        // Only the first emitted result matters; cache() replays it (or the timeout error) to everyone.
        this.cachedInitResult = this.initializeProcessor.next().timeout(this.initializationTimeout).cache();

        HashMap<String, McpClientSession.RequestHandler<?>> requestHandlers = new HashMap<String, McpClientSession.RequestHandler<?>>();
        if (this.clientCapabilities.roots() != null) {
            requestHandlers.put("roots/list", this.rootsListRequestHandler());
        }
        if (this.clientCapabilities.sampling() != null) {
            if (features.samplingHandler() == null) {
                throw new McpError("Sampling handler must not be null when client capabilities include sampling");
            }
            this.samplingHandler = features.samplingHandler();
            requestHandlers.put("sampling/createMessage", this.samplingCreateMessageHandler());
        }

        HashMap<String, McpClientSession.NotificationHandler> notificationHandlers = new HashMap<String, McpClientSession.NotificationHandler>();
        if (!Utils.isEmpty(features.toolsChangeConsumers())) {
            notificationHandlers.put("notifications/tools/list_changed", this.asyncToolsListChangedNotificationHandler(features.toolsChangeConsumers()));
        }
        if (!Utils.isEmpty(features.resourcesChangeConsumers())) {
            notificationHandlers.put("notifications/resources/list_changed", this.asyncResourcesListChangedNotificationHandler(features.resourcesChangeConsumers()));
        }
        if (!Utils.isEmpty(features.promptsChangeConsumers())) {
            notificationHandlers.put("notifications/prompts/list_changed", this.asyncPromptsListChangedNotificationHandler(features.promptsChangeConsumers()));
        }
        if (!Utils.isEmpty(features.rootsChangeConsumers())) {
            notificationHandlers.put("notifications/roots/list_changed", this.asyncRootsListChangedNotificationHandler(features.rootsChangeConsumers()));
        }

        // A debug-logging consumer always runs first, followed by any user-supplied consumers.
        List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumersFinal = new ArrayList<Function<McpSchema.LoggingMessageNotification, Mono<Void>>>();
        loggingConsumersFinal.add(notification -> Mono.fromRunnable(() -> logger.debug("Logging: {}", notification)));
        if (!Utils.isEmpty(features.loggingConsumers())) {
            loggingConsumersFinal.addAll(features.loggingConsumers());
        }
        notificationHandlers.put("notifications/message", this.asyncLoggingNotificationHandler(loggingConsumersFinal));

        this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers);
    }

    /** Returns the server capabilities from the initialize response, or null if none was received. */
    public McpSchema.ServerCapabilities getServerCapabilities() {
        return this.serverCapabilities;
    }

    /** Returns the server implementation info from the initialize response, or null if none was received. */
    public McpSchema.Implementation getServerInfo() {
        return this.serverInfo;
    }

    /** Returns true once {@link #initialize()} has accepted the server's protocol version. */
    public boolean isInitialized() {
        return this.initialized.get();
    }

    /** Returns the capabilities this client was configured with. */
    public McpSchema.ClientCapabilities getClientCapabilities() {
        return this.clientCapabilities;
    }

    /** Returns the implementation info this client was configured with. */
    public McpSchema.Implementation getClientInfo() {
        return this.clientInfo;
    }

    /** Closes the underlying session immediately. */
    public void close() {
        this.mcpSession.close();
    }

    /**
     * Completes the initialization signal and gracefully closes the underlying session.
     *
     * <p>The session's close {@code Mono} is composed into the returned one rather than blocked
     * on. Blocking would tie up (or be rejected on) Reactor's non-blocking threads.
     *
     * @return a Mono that completes when the session has closed
     */
    public Mono<Void> closeGracefully() {
        return Mono.defer(() -> {
            this.initializeSink.complete();
            return this.mcpSession.closeGracefully();
        });
    }

    /**
     * Performs the MCP initialize handshake.
     *
     * <p>The last entry of the supported protocol versions is requested. The server's
     * capabilities and info are recorded from the response. If the server's version is
     * supported, the result is published to waiting operations, the client is marked
     * initialized, and {@code notifications/initialized} is sent.
     *
     * @return the server's initialize result
     */
    public Mono<McpSchema.InitializeResult> initialize() {
        String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);
        McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion, this.clientCapabilities, this.clientInfo);
        return this.mcpSession.sendRequest("initialize", initializeRequest, INITIALIZE_RESULT_TYPE_REF).flatMap(initializeResult -> {
            this.serverCapabilities = initializeResult.capabilities();
            this.serverInfo = initializeResult.serverInfo();
            logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}",
                    initializeResult.protocolVersion(), initializeResult.capabilities(),
                    initializeResult.serverInfo(), initializeResult.instructions());
            if (!this.protocolVersions.contains(initializeResult.protocolVersion())) {
                return Mono.error(new McpError("Unsupported protocol version from the server: " + initializeResult.protocolVersion()));
            }
            this.initializeSink.next(initializeResult);
            this.initialized.set(true);
            return this.mcpSession.sendNotification("notifications/initialized", null).thenReturn(initializeResult);
        });
    }

    /**
     * Runs {@code operation} once the client is initialized.
     *
     * <p>Any error while waiting for initialization (including the timeout) is reported as an
     * {@link McpError} naming {@code actionName}. Errors raised by the operation itself pass
     * through unchanged.
     */
    private <T> Mono<T> withInitializationCheck(String actionName, Function<McpSchema.InitializeResult, Mono<T>> operation) {
        return this.cachedInitResult
                .onErrorResume(throwable -> Mono.error(new McpError("Client must be initialized before " + actionName)))
                .flatMap(operation);
    }

    /** Sends a {@code ping} request to the server. */
    public Mono<Object> ping() {
        return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession.sendRequest("ping", null, OBJECT_TYPE_REF));
    }

    /**
     * Registers a root and, if roots list-changed is enabled and the client is initialized,
     * notifies the server.
     *
     * <p>The root is added when this method is called, not when the returned Mono is subscribed.
     *
     * @param root root to add; its URI must not already be registered
     * @return the notification Mono, or an empty/error Mono
     */
    public Mono<Void> addRoot(McpSchema.Root root) {
        if (root == null) {
            return Mono.error(new McpError("Root must not be null"));
        }
        if (this.clientCapabilities.roots() == null) {
            return Mono.error(new McpError("Client must be configured with roots capabilities"));
        }
        // putIfAbsent performs the duplicate check and insertion as one atomic step.
        if (this.roots.putIfAbsent(root.uri(), root) != null) {
            return Mono.error(new McpError("Root with uri '" + root.uri() + "' already exists"));
        }
        logger.debug("Added root: {}", root);
        return this.notifyRootsListChangedIfEnabled();
    }

    /**
     * Unregisters a root and, if roots list-changed is enabled and the client is initialized,
     * notifies the server.
     *
     * <p>The root is removed when this method is called, not when the returned Mono is subscribed.
     *
     * @param rootUri URI of the root to remove
     * @return the notification Mono, an empty Mono, or an error if the root is unknown
     */
    public Mono<Void> removeRoot(String rootUri) {
        if (rootUri == null) {
            return Mono.error(new McpError("Root uri must not be null"));
        }
        if (this.clientCapabilities.roots() == null) {
            return Mono.error(new McpError("Client must be configured with roots capabilities"));
        }
        McpSchema.Root removed = this.roots.remove(rootUri);
        if (removed == null) {
            return Mono.error(new McpError("Root with uri '" + rootUri + "' not found"));
        }
        logger.debug("Removed Root: {}", rootUri);
        return this.notifyRootsListChangedIfEnabled();
    }

    /**
     * Shared tail of {@link #addRoot} and {@link #removeRoot}.
     *
     * <p>Sends the roots list-changed notification when the capability flag is set and the client
     * is initialized. Otherwise it returns an empty Mono, logging a warning if the only reason is
     * missing initialization.
     */
    private Mono<Void> notifyRootsListChangedIfEnabled() {
        // Boolean.TRUE.equals also treats an unset flag as disabled, in case listChanged() is boxed.
        if (!Boolean.TRUE.equals(this.clientCapabilities.roots().listChanged())) {
            return Mono.empty();
        }
        if (this.isInitialized()) {
            return this.rootsListChangedNotification();
        }
        logger.warn("Client is not initialized, ignore sending a roots list changed notification");
        return Mono.empty();
    }

    /** Sends {@code notifications/roots/list_changed} to the server once initialized. */
    public Mono<Void> rootsListChangedNotification() {
        return this.withInitializationCheck("sending roots list changed notification", initResult -> this.mcpSession.sendNotification("notifications/roots/list_changed"));
    }

    /** Handles {@code roots/list}. */
    private McpClientSession.RequestHandler<McpSchema.ListRootsResult> rootsListRequestHandler() {
        return params -> {
            // Params are decoded as a PaginatedRequest, but the cursor is not applied:
            // every registered root is returned in a single result.
            @SuppressWarnings("unused")
            McpSchema.PaginatedRequest request = this.transport.unmarshalFrom(params, PAGINATED_REQUEST_TYPE_REF);
            List<McpSchema.Root> snapshot = new ArrayList<McpSchema.Root>(this.roots.values());
            return Mono.just(new McpSchema.ListRootsResult(snapshot));
        };
    }

    /** Handles {@code sampling/createMessage} by delegating to the configured sampling handler. */
    private McpClientSession.RequestHandler<McpSchema.CreateMessageResult> samplingCreateMessageHandler() {
        return params -> {
            McpSchema.CreateMessageRequest request = this.transport.unmarshalFrom(params, CREATE_MESSAGE_REQUEST_TYPE_REF);
            return this.samplingHandler.apply(request);
        };
    }

    /** Calls a server tool; fails if the server did not advertise the tools capability. */
    public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) {
        return this.withInitializationCheck("calling tools", initializedResult -> {
            if (this.serverCapabilities.tools() == null) {
                return Mono.error(new McpError("Server does not provide tools capability"));
            }
            return this.mcpSession.sendRequest("tools/call", callToolRequest, CALL_TOOL_RESULT_TYPE_REF);
        });
    }

    /** Lists tools starting from the first page. */
    public Mono<McpSchema.ListToolsResult> listTools() {
        return this.listTools(null);
    }

    /** Lists tools from the page identified by {@code cursor} (null for the first page). */
    public Mono<McpSchema.ListToolsResult> listTools(String cursor) {
        return this.withInitializationCheck("listing tools", initializedResult -> {
            if (this.serverCapabilities.tools() == null) {
                return Mono.error(new McpError("Server does not provide tools capability"));
            }
            return this.mcpSession.sendRequest("tools/list", new McpSchema.PaginatedRequest(cursor), LIST_TOOLS_RESULT_TYPE_REF);
        });
    }

    /** On tools list-changed, re-fetches the first page of tools and passes it to every consumer. */
    private McpClientSession.NotificationHandler asyncToolsListChangedNotificationHandler(List<Function<List<McpSchema.Tool>, Mono<Void>>> toolsChangeConsumers) {
        return params -> this.listTools().flatMap(listToolsResult -> Flux.fromIterable(toolsChangeConsumers)
                .flatMap(consumer -> consumer.apply(listToolsResult.tools()))
                .then());
    }

    /** Lists resources starting from the first page. */
    public Mono<McpSchema.ListResourcesResult> listResources() {
        return this.listResources(null);
    }

    /** Lists resources from the page identified by {@code cursor} (null for the first page). */
    public Mono<McpSchema.ListResourcesResult> listResources(String cursor) {
        return this.withInitializationCheck("listing resources", initializedResult -> {
            if (this.serverCapabilities.resources() == null) {
                return Mono.error(new McpError("Server does not provide the resources capability"));
            }
            return this.mcpSession.sendRequest("resources/list", new McpSchema.PaginatedRequest(cursor), LIST_RESOURCES_RESULT_TYPE_REF);
        });
    }

    /** Reads the contents of {@code resource} by its URI. */
    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.Resource resource) {
        return this.readResource(new McpSchema.ReadResourceRequest(resource.uri()));
    }

    /** Sends a {@code resources/read} request; fails if the server lacks the resources capability. */
    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.ReadResourceRequest readResourceRequest) {
        return this.withInitializationCheck("reading resources", initializedResult -> {
            if (this.serverCapabilities.resources() == null) {
                return Mono.error(new McpError("Server does not provide the resources capability"));
            }
            return this.mcpSession.sendRequest("resources/read", readResourceRequest, READ_RESOURCE_RESULT_TYPE_REF);
        });
    }

    /** Lists resource templates starting from the first page. */
    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates() {
        return this.listResourceTemplates(null);
    }

    /** Lists resource templates from the page identified by {@code cursor} (null for the first page). */
    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String cursor) {
        return this.withInitializationCheck("listing resource templates", initializedResult -> {
            if (this.serverCapabilities.resources() == null) {
                return Mono.error(new McpError("Server does not provide the resources capability"));
            }
            return this.mcpSession.sendRequest("resources/templates/list", new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF);
        });
    }

    /** Sends a {@code resources/subscribe} request. */
    public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
        return this.withInitializationCheck("subscribing to resources", initializedResult -> this.mcpSession.sendRequest("resources/subscribe", subscribeRequest, VOID_TYPE_REFERENCE));
    }

    /** Sends a {@code resources/unsubscribe} request. */
    public Mono<Void> unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) {
        return this.withInitializationCheck("unsubscribing from resources", initializedResult -> this.mcpSession.sendRequest("resources/unsubscribe", unsubscribeRequest, VOID_TYPE_REFERENCE));
    }

    /** On resources list-changed, re-fetches the first page of resources and passes it to every consumer. */
    private McpClientSession.NotificationHandler asyncResourcesListChangedNotificationHandler(List<Function<List<McpSchema.Resource>, Mono<Void>>> resourcesChangeConsumers) {
        return params -> this.listResources().flatMap(listResourcesResult -> Flux.fromIterable(resourcesChangeConsumers)
                .flatMap(consumer -> consumer.apply(listResourcesResult.resources()))
                .then());
    }

    /** Lists prompts starting from the first page. */
    public Mono<McpSchema.ListPromptsResult> listPrompts() {
        return this.listPrompts(null);
    }

    /** Lists prompts from the page identified by {@code cursor} (null for the first page). */
    public Mono<McpSchema.ListPromptsResult> listPrompts(String cursor) {
        return this.withInitializationCheck("listing prompts", initializedResult -> this.mcpSession.sendRequest("prompts/list", new McpSchema.PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF));
    }

    /** Sends a {@code prompts/get} request. */
    public Mono<McpSchema.GetPromptResult> getPrompt(McpSchema.GetPromptRequest getPromptRequest) {
        return this.withInitializationCheck("getting prompts", initializedResult -> this.mcpSession.sendRequest("prompts/get", getPromptRequest, GET_PROMPT_RESULT_TYPE_REF));
    }

    /** On prompts list-changed, re-fetches the first page of prompts and passes it to every consumer. */
    private McpClientSession.NotificationHandler asyncPromptsListChangedNotificationHandler(List<Function<List<McpSchema.Prompt>, Mono<Void>>> promptsChangeConsumers) {
        return params -> this.listPrompts().flatMap(listPromptsResult -> Flux.fromIterable(promptsChangeConsumers)
                .flatMap(consumer -> consumer.apply(listPromptsResult.prompts()))
                .then());
    }

    /** Decodes a {@code notifications/message} payload and passes it to every logging consumer. */
    private McpClientSession.NotificationHandler asyncLoggingNotificationHandler(List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers) {
        return params -> {
            McpSchema.LoggingMessageNotification notification = this.transport.unmarshalFrom(params, LOGGING_MESSAGE_NOTIFICATION_TYPE_REF);
            return Flux.fromIterable(loggingConsumers)
                    .flatMap(consumer -> consumer.apply(notification))
                    .then();
        };
    }

    /**
     * Asks the server to change its minimum logging level.
     *
     * <p>MCP defines {@code logging/setLevel} as a request answered with an empty result, not a
     * notification. It is therefore sent with {@code sendRequest}, and the returned Mono completes
     * once the server has acknowledged it.
     *
     * @param loggingLevel level to request; must not be null
     */
    public Mono<Void> setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
        if (loggingLevel == null) {
            return Mono.error(new McpError("Logging level must not be null"));
        }
        return this.withInitializationCheck("setting logging level", initializedResult -> {
            String levelName = this.transport.unmarshalFrom(loggingLevel, STRING_TYPE_REF);
            HashMap<String, Object> params = new HashMap<String, Object>();
            params.put("level", levelName);
            return this.mcpSession.sendRequest("logging/setLevel", params, OBJECT_TYPE_REF).then();
        });
    }

    /**
     * Replaces the supported protocol versions. The last entry is the one requested by
     * {@link #initialize()}.
     *
     * @param protocolVersions non-null, non-empty list; it is copied defensively
     * @throws IllegalArgumentException if the list is empty
     */
    void setProtocolVersions(List<String> protocolVersions) {
        Assert.notNull(protocolVersions, "Protocol versions must not be null");
        if (protocolVersions.isEmpty()) {
            throw new IllegalArgumentException("Protocol versions must not be empty");
        }
        this.protocolVersions = Collections.unmodifiableList(new ArrayList<String>(protocolVersions));
    }

    /** Passes a snapshot of the currently registered roots to every roots-change consumer. */
    private McpClientSession.NotificationHandler asyncRootsListChangedNotificationHandler(List<Function<List<McpSchema.Root>, Mono<Void>>> rootsChangeConsumers) {
        return params -> Flux.fromIterable(rootsChangeConsumers)
                .flatMap(consumer -> consumer.apply(new ArrayList<McpSchema.Root>(this.roots.values())))
                .then();
    }
}

