/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.gptas.autoact.agent.executor;

import com.alibaba.fastjson.serializer.SerializerFeature;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import kd.bos.gptas.autoact.agent.AgentContext;
import kd.bos.gptas.autoact.agent.AgentContextImpl;
import kd.bos.gptas.autoact.agent.Agents;
import kd.bos.gptas.autoact.agent.callchain.AgentCall;
import kd.bos.gptas.autoact.agent.callchain.AgentCallImpl;
import kd.bos.gptas.autoact.agent.executor.InvokeData;
import kd.bos.gptas.autoact.agent.executor.MethodAction;
import kd.bos.gptas.autoact.agent.safepoint.SafePoint;
import kd.bos.gptas.autoact.exception.AgentDefineError;
import kd.bos.gptas.autoact.exception.ChatException;
import kd.bos.gptas.autoact.exception.ExceptionUtil;
import kd.bos.gptas.autoact.log.Logable;
import kd.bos.gptas.autoact.memory.ChatMemory;
import kd.bos.gptas.autoact.memory.EmptyChatMemory;
import kd.bos.gptas.autoact.message.AIMessage;
import kd.bos.gptas.autoact.message.ChatMessage;
import kd.bos.gptas.autoact.message.ChatMessageType;
import kd.bos.gptas.autoact.message.SystemMessage;
import kd.bos.gptas.autoact.message.UserMessage;
import kd.bos.gptas.autoact.model.Action;
import kd.bos.gptas.autoact.model.ActionModel;
import kd.bos.gptas.autoact.model.Agent;
import kd.bos.gptas.autoact.model.AgentImpl;
import kd.bos.gptas.autoact.model.BeforeOrAfterActionModel;
import kd.bos.gptas.autoact.model.Tool;
import kd.bos.gptas.autoact.output.OutputParser;
import kd.bos.gptas.autoact.output.OutputParserFactory;
import kd.bos.gptas.autoact.output.parser.StringParser;
import kd.bos.gptas.autoact.prompt.Prompt;
import kd.bos.gptas.autoact.prompt.PromptBuilder;
import kd.bos.gptas.autoact.prompt.var.ContextVarProvider;
import kd.bos.gptas.autoact.util.JsonUtil;
import kd.bos.gptas.autoact.util.RetryUtils;
import kd.bos.gptas.autoact.util.StringUtil;

public final class AgentExecutor
implements Logable {
    private static final AutoCloseable EMPTY_AUTO_CLOSEABLE = () -> {};
    private final LinkedList<AgentContextImpl> ctxQueue = new LinkedList();
    private AgentContextImpl ctx;
    private final Map<Class<?>, Object> objectMap = new ConcurrentHashMap();
    private final Map<String, Action<?, ?>> actionMap = new ConcurrentHashMap();

    public AgentExecutor(AgentContext ctx) {
        this.ctx = (AgentContextImpl)ctx;
    }

    public Object execute(Method method, Object[] args, List<Tool> tools, boolean loggable, AgentCall chainContext) throws Throwable {
        Tool methodTool = ((AgentImpl)this.ctx.getAgent()).findMethodTool(method);
        InvokeData requestData = InvokeData.fromArguments(methodTool, args);
        InvokeData responseData = this.__execute(methodTool, requestData, tools, chainContext, loggable);
        return responseData.toReturn(method.getReturnType(), this.ctx);
    }

    private InvokeData __execute(Tool tool, InvokeData requestData, List<Tool> tools, AgentCall chainContext, boolean loggable) throws Throwable {
        if (tool instanceof Agent) {
            Agent agent = (Agent)((Object)tool);
            try (AutoCloseable ac = this.intoScope(agent);){
                String defaultToolName = agent.getDefaultTool();
                if (StringUtil.isEmpty(agent.getAction()) && !StringUtil.isEmpty(defaultToolName)) {
                    tool = this.ctx.getTool(defaultToolName);
                    if (tool == null) {
                        throw new AgentDefineError("The default tool for the agent \"" + agent.getName() + "\"  does not exists: " + defaultToolName);
                    }
                    if (loggable) {
                        logger.info("agent " + agent.getName() + " turn to the default tool:  " + defaultToolName);
                    }
                    if (tool instanceof Agent) {
                        InvokeData invokeData = this.__execute(tool, requestData, this.ctx.getAllTools(), chainContext, loggable);
                        return invokeData;
                    }
                }
                requestData = this.fireBeforeToolChain(agent.asTool(), requestData, chainContext, loggable);
                requestData = this.__execute0(tool, requestData, tools, chainContext, loggable);
                InvokeData invokeData = requestData = this.fireAfterToolChain(agent.asTool(), requestData, chainContext, loggable);
                return invokeData;
            }
        }
        return this.__execute0(tool, requestData, tools, chainContext, loggable);
    }

    private InvokeData __execute0(Tool tool, InvokeData requestData, List<Tool> tools, AgentCall chainContext, boolean loggable) throws Throwable {
        ActionModel actionModel = tool.getActionModel();
        switch (actionModel) {
            case REACT: {
                Action action = this.getToolAction(tool);
                if (action != null) {
                    return this.__action(action, tool, requestData, chainContext, loggable);
                }
                List<Tool> chooseTools = this.chooseExportedTools(tool, tools);
                if (!chooseTools.isEmpty()) {
                    return this.__choose(tool, requestData, chooseTools, chainContext, loggable);
                }
                return this.__chat(tool, requestData, chainContext, loggable);
            }
            case CHOOSE: {
                List<Tool> chooseTools = this.chooseExportedTools(tool, tools);
                if (!chooseTools.isEmpty()) {
                    return this.__choose(tool, requestData, chooseTools, chainContext, loggable);
                }
                return this.__chat(tool, requestData, chainContext, loggable);
            }
            case ACTION: {
                Action action = this.getToolAction(tool);
                if (action == null) {
                    return this.__chat(tool, requestData, chainContext, loggable);
                }
                return this.__action(action, tool, requestData, chainContext, loggable);
            }
            case CHAT: {
                return this.__chat(tool, requestData, chainContext, loggable);
            }
        }
        throw new UnsupportedOperationException("Unsupported action type: " + (Object)((Object)actionModel));
    }

    private InvokeData __choose(Tool methodTool, InvokeData requestData, List<Tool> tools, AgentCall chainContext, boolean loggable) throws Throwable {
        return this.__choose0(methodTool, requestData, tools, chainContext, false, loggable);
    }

    /*
     * Loose catch block
     */
    private InvokeData __choose0(Tool methodTool, InvokeData requestData, List<Tool> tools, AgentCall chainContext, boolean onlyInvokeMatchedTool, boolean loggable) throws Throwable {
        PromptBuilder chooseToolPrompt = new PromptBuilder();
        chooseToolPrompt.add("Question: \n```html\n").add(StringUtil.join(",", requestData)).add("\n```").addSeparator();
        chooseToolPrompt.addLine("Don't answer the question, just need to chosen the available tools, reply with available tool No.x (x is the tool number), otherwise return No.0: ");
        int i = 0;
        for (Tool tool : tools) {
            HashMap<String, String> descMap = new HashMap<String, String>();
            descMap.put("name", tool.getName());
            PromptBuilder promptBuilder = Prompt.builder();
            if (!StringUtil.isEmpty(tool.getIntentPrompt())) {
                promptBuilder.add(tool.getIntentPrompt());
            } else {
                promptBuilder.add(tool.getActionPrompt());
            }
            String prompt = promptBuilder.build(this.ctx.getContextVarProvider()).getContent();
            descMap.put("description", prompt);
            chooseToolPrompt.add("No.").add(++i).add(Character.valueOf(' ')).add(JsonUtil.toJSONString(descMap, new SerializerFeature[0]));
            chooseToolPrompt.addLine();
        }
        if (loggable) {
            logger.info("try choose tool: " + tools.stream().map(Tool::getName).collect(Collectors.toList()));
        }
        String chosen = ((AIMessage)RetryUtils.withRetry(() -> {
            this.ctx.getSession().getSafePoint().poll();
            return this.ctx.getChatModel().generate(chooseToolPrompt.toString());
        }).getResult()).getMessage();
        this.ctx.getSession().getSafePoint().poll();
        if (loggable) {
            logger.info("try choose tool reply: " + chosen);
        }
        if (chosen.startsWith("No.")) {
            chosen = chosen.substring(3);
        }
        if (chosen.matches("\\d+") && (i = Integer.parseInt(chosen)) > 0) {
            Tool tool;
            tool = tools.get(i - 1);
            if (loggable) {
                logger.info("chosen tool-" + i + ": " + tool.getName());
            }
            if (!onlyInvokeMatchedTool) {
                ((AgentCallImpl)chainContext).addToolCall(methodTool);
            }
            if (tool instanceof Agent) {
                Object input = requestData.toArgument(Object.class, this.ctx);
                this.ctx.getContextVarProvider().setToolInput(tool, input);
                Object ret = null;
                try {
                    try (AutoCloseable ac = this.intoScope((Agent)((Object)tool));){
                        ret = Agents.callAgent(((AgentImpl)tool).getAgentInstance(), input);
                        InvokeData invokeData = InvokeData.fromResult(ret);
                        return invokeData;
                    }
                    {
                        catch (Throwable throwable) {
                            throw throwable;
                        }
                    }
                }
                finally {
                    this.ctx.getContextVarProvider().setToolOutput(tool, ret);
                }
            }
            return this.__execute(tool, requestData, Collections.emptyList(), chainContext, loggable);
        }
        if (loggable) {
            logger.info("No available tool matched, chat with the model.");
        }
        if (onlyInvokeMatchedTool) {
            return requestData;
        }
        return this.__chat(methodTool, requestData, chainContext, loggable);
    }

    private InvokeData __action(Action action, Tool tool, InvokeData requestData, AgentCall chainContext, boolean loggable) throws Throwable {
        InvokeData beforeData = this.fireBeforeToolChain(tool, requestData, chainContext, loggable);
        InvokeData data = this.__action0(action, tool, beforeData, chainContext, loggable);
        InvokeData afterData = this.fireAfterToolChain(tool, data, chainContext, loggable);
        return afterData;
    }

    private InvokeData __action0(Action action, Tool tool, InvokeData requestData, AgentCall chainContext, boolean loggable) throws Throwable {
        Object value;
        this.ctx.getSession().getSafePoint().poll();
        ((AgentCallImpl)chainContext).addToolCall(tool);
        if (action == null) {
            throw new AgentDefineError("The action of tool  \"" + tool.getName() + "\" can't be empty.");
        }
        this.ctx.fireAgentListener(al -> al.onBeforeAction(tool, requestData, this.ctx));
        Object input = requestData.toReturn(Object.class, this.ctx);
        this.ctx.getContextVarProvider().setToolInput(tool, input);
        Class<?> actionParameterType = action.getParameterType();
        Object arg = requestData.toArgument(actionParameterType, this.ctx);
        if (loggable) {
            logger.info("[ACTION call \"" + tool.getAction() + "\"] " + arg);
            value = action.act(arg);
            logger.info("[ACTION return \"" + tool.getAction() + "\"] " + value);
        } else {
            value = action.act(arg);
        }
        this.ctx.getContextVarProvider().setToolOutput(tool, value);
        InvokeData ret = InvokeData.fromResult(value);
        this.ctx.fireAgentListener(al -> al.onAfterAction(tool, ret, this.ctx));
        this.ctx.getSession().getSafePoint().poll();
        return ret;
    }

    private InvokeData __chat(Tool tool, InvokeData requestData, AgentCall chainContext, boolean loggable) throws Throwable {
        InvokeData beforeData = this.fireBeforeToolChain(tool, requestData, chainContext, loggable);
        InvokeData data = this.__chat0(tool, beforeData, chainContext, loggable);
        InvokeData afterData = this.fireAfterToolChain(tool, data, chainContext, loggable);
        return afterData;
    }

    private InvokeData __chat0(Tool tool, InvokeData args, AgentCall chainContext, boolean loggable) throws Throwable {
        OutputParser outputParser;
        String formatPrompt;
        String defaultAction;
        SafePoint safePoint = this.ctx.getSession().getSafePoint();
        safePoint.poll();
        ((AgentCallImpl)chainContext).addToolCall(tool);
        if (tool instanceof Agent && !StringUtil.isEmpty(defaultAction = ((Agent)((Object)tool)).getDefaultAction())) {
            try (AutoCloseable ac = this.intoScope((Agent)((Object)tool));){
                Action action = this.ctx.getAction(defaultAction);
                if (action == null) {
                    throw new AgentDefineError("Agent " + tool.getName() + "'s action " + defaultAction + " not define.");
                }
                InvokeData invokeData = this.__action0(action, tool, args, chainContext, loggable);
                return invokeData;
            }
        }
        this.ctx.fireAgentListener(al -> al.onBeforeChat(tool, args, this.ctx));
        Object input = args.toReturn(Object.class, this.ctx);
        this.ctx.getContextVarProvider().setToolInput(tool, input);
        ChatMemory memory = this.ctx.getMemory();
        boolean useMemory = memory != EmptyChatMemory.INSTANCE;
        ContextVarProvider contextVarProvider = this.ctx.getContextVarProvider();
        List<ChatMessage> historyMessages = memory.messages();
        Agent agent = this.ctx.getAgent();
        if (useMemory && (historyMessages.isEmpty() || historyMessages.get(0).getType() != ChatMessageType.SYSTEM) && !StringUtil.isEmpty(agent.getSystemPrompt())) {
            String sysMessage = Prompt.builder().add(agent.getSystemPrompt()).build(contextVarProvider).getContent();
            memory.add(new SystemMessage(sysMessage));
        }
        PromptBuilder promptBuilder = new PromptBuilder();
        String methodPrompt = tool.genActionPrompt(args, contextVarProvider);
        promptBuilder.add(methodPrompt);
        Class returnType = null;
        if (chainContext.getCalledTools().size() == 1) {
            returnType = chainContext.getMethodName().equalsIgnoreCase("__input__") ? String.class : chainContext.getMethodReturnType();
        }
        if (!StringUtil.isEmpty(formatPrompt = (outputParser = OutputParserFactory.createParser(returnType, tool, this.ctx)).outputFormatPrompt(this.ctx.getChatModel().name()))) {
            promptBuilder.addSeparator().add(formatPrompt);
        }
        String inputMessage = promptBuilder.build(contextVarProvider).getContent();
        ArrayList<ChatMessage> messages = new ArrayList<ChatMessage>(10);
        UserMessage userMessage = new UserMessage(inputMessage);
        if (useMemory) {
            messages.addAll(memory.messages());
            memory.add(userMessage);
        }
        messages.add(userMessage);
        this.ctx.fireAgentListener(al -> al.onBeforeChatMessage(tool, messages, this.ctx));
        if (loggable) {
            logger.info("[CHAT call \"" + tool.getName() + "\"] " + messages);
        }
        if (messages.isEmpty() || messages.size() == 1 && ((ChatMessage)messages.get(0)).getMessageText().isEmpty()) {
            throw new ChatException("message content can not be empty");
        }
        AIMessage aiMessage = (AIMessage)RetryUtils.withRetry(() -> {
            safePoint.poll();
            return this.ctx.getChatModel().generate(messages);
        }).getResult();
        if (useMemory) {
            memory.add(aiMessage, true);
        }
        if (loggable && aiMessage.isStream() && this.ctx.isDebugMessageStream()) {
            System.out.println("[DebugMessageStream]");
            System.out.println("-------------------------------------");
            aiMessage.getMessageStream().forEachRemaining(s -> {
                safePoint.poll();
                System.out.print((String)s);
            });
            System.out.println("\n-------------------------------------");
        }
        safePoint.poll();
        Object value = outputParser instanceof StringParser && aiMessage.isStream() ? aiMessage.getMessageStream() : outputParser.parse(aiMessage.getMessageText());
        this.ctx.getContextVarProvider().setToolOutput(tool, value);
        if (loggable) {
            logger.info("[CHAT return \"" + tool.getName() + "\"] " + value);
        }
        InvokeData ret = InvokeData.fromResult(value);
        this.ctx.fireAgentListener(al -> al.onAfterChat(tool, ret, this.ctx));
        safePoint.poll();
        return ret;
    }

    private InvokeData fireBeforeToolChain(Tool tool, InvokeData args, AgentCall chainContext, boolean loggable) throws Throwable {
        return this.doToolChain("fire before \"" + tool.getName() + '\"', tool.getBeforeActionTools(), tool.getBeforeActionModel(), tool, args, chainContext, loggable);
    }

    private InvokeData fireAfterToolChain(Tool tool, InvokeData args, AgentCall chainContext, boolean loggable) throws Throwable {
        return this.doToolChain("fire  after \"" + tool.getName() + '\"', tool.getAfterActionTools(), tool.getAfterActionModel(), tool, args, chainContext, loggable);
    }

    private InvokeData doToolChain(String logTag, List<String> chainToolNames, BeforeOrAfterActionModel actionModel, Tool tool, InvokeData requestData, AgentCall chainContext, boolean loggable) throws Throwable {
        if (chainToolNames.isEmpty()) {
            return requestData;
        }
        List<Tool> allTools = this.ctx.getAllTools();
        HashMap<String, Tool> toolMap = new HashMap<String, Tool>(allTools.size());
        for (Tool item : allTools) {
            toolMap.put(item.getName(), item);
        }
        switch (actionModel) {
            case FLOW: {
                for (String toolName : chainToolNames) {
                    if (loggable) {
                        logger.info(logTag + " by flow: " + toolName);
                    }
                    Tool chainTool = (Tool)toolMap.get(toolName);
                    List<Tool> tools = this.chooseExportedTools(chainTool, allTools);
                    tools.remove(tool);
                    requestData = this.__execute(chainTool, requestData, tools, chainContext, loggable);
                }
                return requestData;
            }
            case MATCH: {
                if (loggable) {
                    logger.info(logTag + " by match: " + chainToolNames);
                }
                ArrayList<Tool> chooseTools = new ArrayList<Tool>(chainToolNames.size());
                for (String toolName : chainToolNames) {
                    Tool chainTool = (Tool)toolMap.get(toolName);
                    chooseTools.add(chainTool);
                }
                return this.__choose0(tool, requestData, chooseTools, chainContext, true, loggable);
            }
        }
        throw new UnsupportedOperationException("Unsupported action model at " + logTag);
    }

    public Action getToolAction(Tool tool) {
        String actionExpression = tool.getActionExpression(this.ctx);
        if (StringUtil.isEmpty(actionExpression)) {
            return null;
        }
        Action action = this.ctx.getAction(actionExpression);
        if (action != null) {
            return action;
        }
        return this.actionMap.computeIfAbsent(actionExpression, exp -> {
            int p = exp.indexOf(35);
            if (p != -1) {
                Class<?> cls;
                String name = exp.substring(0, p);
                String methodName = exp.substring(p + 1);
                try {
                    cls = Class.forName(name);
                }
                catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                Object obj = this.objectMap.computeIfAbsent(cls, clazz -> {
                    if (clazz.isInterface()) {
                        return clazz == this.ctx.getAgentClass() ? ((AgentImpl)this.ctx.getAgent()).getAgentInstance() : null;
                    }
                    try {
                        try {
                            Constructor cc = clazz.getConstructor(AgentContext.class);
                            if (Modifier.isPublic(cc.getModifiers())) {
                                return cc.newInstance(this.ctx);
                            }
                        }
                        catch (NoSuchMethodException cc) {
                            // empty catch block
                        }
                        return clazz.newInstance();
                    }
                    catch (Exception e) {
                        throw ExceptionUtil.asRuntimeException(e);
                    }
                });
                Method[] ms = cls.getMethods();
                for (int i = ms.length - 1; i >= 0; --i) {
                    Method method = ms[i];
                    if (!Modifier.isPublic(method.getModifiers()) || !methodName.equals(method.getName()) || Modifier.isAbstract(method.getModifiers())) continue;
                    return new MethodAction(tool, obj, method, this.ctx);
                }
            }
            throw new IllegalArgumentException("Not define tool: " + exp);
        });
    }

    private List<Tool> chooseExportedTools(Tool methodTool, List<Tool> tools) {
        if (tools.isEmpty()) {
            return Collections.emptyList();
        }
        int parameterCount = methodTool.getParameters().size();
        ArrayList<Tool> chooseTools = new ArrayList<Tool>(tools.size() - 1);
        for (Tool tool : tools) {
            ActionModel toolActionModel;
            if (tool == methodTool || !tool.isExport() || (toolActionModel = tool.getActionModel()) != ActionModel.REACT && toolActionModel != ActionModel.ACTION && toolActionModel != ActionModel.CHAT || tool.getParameters().size() != parameterCount) continue;
            chooseTools.add(tool);
        }
        return chooseTools;
    }

    private AutoCloseable intoScope(Agent agent) {
        AgentContextImpl next = (AgentContextImpl)Agents.getAgentContext(((AgentImpl)agent).getAgentInstance());
        if (next == this.ctx) {
            return EMPTY_AUTO_CLOSEABLE;
        }
        this.ctxQueue.add(this.ctx);
        next.setParent(this.ctx);
        this.ctx = next;
        return () -> {
            this.ctx = this.ctxQueue.removeLast();
        };
    }
}

