/*
 * Decompiled with CFR 0.152.
 */
package kd.ai.gai.core.service.agent;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alipay.api.java_websocket.client.WebSocketClient;
import com.alipay.api.java_websocket.handshake.ServerHandshake;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.Date;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import kd.ai.gai.core.api.websocket.query.GaiRequestContext;
import kd.ai.gai.core.api.websocket.query.RequiredParams;
import kd.ai.gai.core.api.websocket.query.WsRequestParams;
import kd.ai.gai.core.cache.SessionCache;
import kd.ai.gai.core.constant.agent.AgentConstants;
import kd.ai.gai.core.domain.dto.agent.AiccAgentConfig;
import kd.ai.gai.core.domain.dto.agent.FileInfo;
import kd.ai.gai.core.enuz.agent.RunStepStatusEnum;
import kd.ai.gai.core.service.agent.AgentServiceService;
import kd.ai.gai.core.service.agent.FileService;
import kd.ai.gai.core.service.agent.RunService;
import kd.ai.gai.core.service.agent.handler.AgentHandlerFactory;
import kd.ai.gai.core.trace.entity.BaseResult;
import kd.ai.gai.core.util.AgentThreadUtils;
import kd.bos.context.RequestContext;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;
import kd.bos.service.KDDateFormatUtils;
import kd.bos.service.KDDateUtils;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;

public class AgentMsgWatchWebSocketService {
    private static final Log log = LogFactory.getLog(AgentMsgWatchWebSocketService.class);
    private static final int MAX_TIME_OUT = 1800000;
    private AgentMsgWatchWebSocketClient agentMsgWatchWebSocketClient;
    BlockingQueue<Object> messageQueue = new LinkedBlockingQueue<Object>();

    private void onMessageQueue(WsRequestParams wsRequestParams) {
        String chatSessionId = this.getChatSessionId();
        RequiredParams requiredParams = wsRequestParams.getRequiredParams();
        String requestId = requiredParams.getRequestId();
        long startTime = KDDateUtils.now().getTime();
        while (KDDateUtils.now().getTime() - startTime < 1800000L) {
            WsRequestParams processWsRequestParams = null;
            try {
                Object obj = this.messageQueue.poll(100L, TimeUnit.MILLISECONDS);
                boolean toBreak = false;
                while (obj != null) {
                    processWsRequestParams = null;
                    startTime = KDDateUtils.now().getTime();
                    String responeRequestId = "";
                    if (obj instanceof String) {
                        processWsRequestParams = this.processText(chatSessionId, (String)obj);
                        if (processWsRequestParams != null) {
                            if (this.webSocketMsgIsEnd(processWsRequestParams)) {
                                toBreak = true;
                            }
                            if (processWsRequestParams.isResponse()) {
                                responeRequestId = processWsRequestParams.getRequiredParams().getRequestId();
                            }
                        }
                    } else {
                        responeRequestId = this.processBinary((ByteBuffer)obj);
                    }
                    if (StringUtils.equalsIgnoreCase((CharSequence)requestId, (CharSequence)responeRequestId)) {
                        Date now = KDDateUtils.now();
                        log.info("\u3010type:{}\u3011, requestId:{} receive response with [{}, {}](cost:{}s)", new Object[]{requiredParams.getType(), requiredParams.getRequestId(), KDDateFormatUtils.getDateTimeFormat().format(new Date(startTime)), KDDateFormatUtils.getDateTimeFormat().format(now), (now.getTime() - startTime) / 1000L});
                    }
                    obj = this.messageQueue.poll(100L, TimeUnit.MILLISECONDS);
                }
                if (!toBreak) continue;
                log.info("toBreak {} with processWsRequestParams: {}", (Object)chatSessionId, (Object)(processWsRequestParams != null ? JSONObject.toJSONString(processWsRequestParams) : null));
                break;
            }
            catch (Exception e) {
                log.error(String.format("\u3010type:%s\u3011, requestId:%s onMessage error : %s", requiredParams.getType(), requiredParams.getRequestId(), e.getMessage()), (Throwable)e);
                this.sendExceptionMsg(processWsRequestParams, chatSessionId, e.getMessage());
            }
        }
        this.closeWs(chatSessionId);
    }

    private void sendExceptionMsg(WsRequestParams wsRequestParams, String chatSessionId, String errMsg) {
        if (wsRequestParams != null) {
            RequiredParams requiredParams = wsRequestParams.getRequiredParams();
            String type = requiredParams.getType();
            String responseType = StringUtils.replace((String)type, (String)WsRequestParams.REQUEST_PREFIX, (String)WsRequestParams.RESPONSE_PREFIX);
            errMsg = String.format("\u6267\u884c\u63a5\u53e3\u5f02\u5e38\uff1a%s - %s", responseType, errMsg);
            requiredParams.setType(responseType);
            requiredParams.setGaiRequestContext(new GaiRequestContext(RequestContext.get()));
            wsRequestParams.setRequiredParams(requiredParams);
            wsRequestParams.setBizParams(BaseResult.fail(errMsg));
            AgentMsgWatchWebSocketService agentMsgwatchWebSocketService = AgentConstants.agentMsgWatchWebSocketServiceMap.get(chatSessionId);
            if (agentMsgwatchWebSocketService != null) {
                agentMsgwatchWebSocketService.sendMsg(chatSessionId, wsRequestParams);
            }
        }
    }

    private WsRequestParams processText(String chatSessionId, String text) {
        log.info("Received WsMessage.text: {}", (Object)text);
        WsRequestParams wsRequestParams = (WsRequestParams)JSON.parseObject((String)text, WsRequestParams.class);
        RequiredParams requiredParams = wsRequestParams.getRequiredParams();
        log.info("Received WsMessage: {} - {}", (Object)requiredParams.getType(), (Object)text);
        if (wsRequestParams.isRequest()) {
            wsRequestParams = AgentHandlerFactory.runHandler(chatSessionId, wsRequestParams);
        }
        return wsRequestParams;
    }

    private boolean webSocketMsgIsEnd(WsRequestParams wsRequestParams) {
        BaseResult baseResult;
        JSONObject dataJo;
        String runStatus;
        RequiredParams requiredParams;
        String type;
        return wsRequestParams != null && wsRequestParams.isResponse() && StringUtils.equalsIgnoreCase((CharSequence)(type = (requiredParams = wsRequestParams.getRequiredParams()).getType()), (CharSequence)"response.session.run.save") && RunStepStatusEnum.isEndStatus(runStatus = (dataJo = (baseResult = (BaseResult)wsRequestParams.getBizParams()).getDataAsJSONObject()).getString("runStatus"));
    }

    private String processBinary(ByteBuffer byteBuffer) {
        String chatSessionId = this.getChatSessionId();
        String tempFileCacheKey = String.format("%s_%s", "tempFile", chatSessionId);
        String fileInfoStr = SessionCache.get().get(tempFileCacheKey);
        FileInfo fileInfo = (FileInfo)JSONObject.parseObject((String)fileInfoStr, FileInfo.class);
        String requestId = fileInfo.getRequestId();
        log.info("processBinary: tempFileCacheKey:{}, requestId:{}, fileInfo:{}", new Object[]{tempFileCacheKey, requestId, fileInfoStr});
        InputStream inputStream = this.byteBufferToInputStream(byteBuffer);
        fileInfo.setInputStream(inputStream);
        fileInfo = FileService.save(fileInfo);
        fileInfo.setInputStream(null);
        SessionCache.get().remove(tempFileCacheKey);
        String type = "response.file.metadata.save";
        WsRequestParams wsRequestParams = new WsRequestParams(type, requestId, new BaseResult(fileInfo));
        AgentMsgWatchWebSocketService agentMsgwatchWebSocketService = AgentConstants.agentMsgWatchWebSocketServiceMap.get(chatSessionId);
        agentMsgwatchWebSocketService.sendMsg(chatSessionId, wsRequestParams);
        return requestId;
    }

    private InputStream byteBufferToInputStream(ByteBuffer byteBuffer) {
        byte[] bytes = new byte[byteBuffer.remaining()];
        byteBuffer.get(bytes);
        return new ByteArrayInputStream(bytes);
    }

    @NotNull
    private String getChatSessionId() {
        String path = this.agentMsgWatchWebSocketClient.getURI().getPath();
        return path.substring(path.lastIndexOf("/") + 1);
    }

    private void connectWs(String chatSessionId, WsRequestParams wsRequestParams) {
        try {
            if (this.agentMsgWatchWebSocketClient == null || this.agentMsgWatchWebSocketClient.isClosed()) {
                AiccAgentConfig aiccAgentConfig = AgentServiceService.getAgentService("gai_agent");
                String wsUri = String.format("%s/agent/msgwatch/%s", aiccAgentConfig.getServerUrl(), chatSessionId);
                this.agentMsgWatchWebSocketClient = new AgentMsgWatchWebSocketClient(new URI(wsUri), this.messageQueue, chatSessionId);
                boolean succ = this.agentMsgWatchWebSocketClient.connectBlocking(10L, TimeUnit.SECONDS);
                if (succ) {
                    log.info("chatSessionId {} create connect success", (Object)chatSessionId);
                    if (wsRequestParams.isRequest()) {
                        log.info("requiredParams.type {} : {} start onMessageQueue", (Object)chatSessionId, (Object)wsRequestParams.getRequiredParams().getType());
                        AgentThreadUtils.execute(() -> this.onMessageQueue(wsRequestParams));
                    }
                }
            } else {
                log.info("chatSessionId {} connect existed and avail.", (Object)chatSessionId);
            }
        }
        catch (Throwable e) {
            log.error(String.format("error : %s", e.getMessage()), e);
        }
    }

    private void closeWs(String chatSessionId) {
        if (this.agentMsgWatchWebSocketClient != null && this.agentMsgWatchWebSocketClient.isOpen()) {
            try {
                log.info("closeWs chatSessionId:{}", (Object)chatSessionId);
                this.agentMsgWatchWebSocketClient.closeBlocking();
            }
            catch (Exception e) {
                log.error(String.format("error : %s", e.getMessage()), (Throwable)e);
            }
            finally {
                log.info("remove agentWebSocketRunner chatSessionId:{}", (Object)chatSessionId);
                this.agentMsgWatchWebSocketClient = null;
                AgentConstants.agentMsgWatchWebSocketServiceMap.remove(chatSessionId);
                RunService.expiredRunOrRunStep(chatSessionId);
            }
        }
    }

    public void sendMsg(String chatSessionId, WsRequestParams wsRequestParams) {
        try {
            String msg = wsRequestParams.toJSONString();
            log.info("chatSessionId:{}, sendMsg: {}", (Object)chatSessionId, (Object)msg);
            this.connectWs(chatSessionId, wsRequestParams);
            this.agentMsgWatchWebSocketClient.send(msg);
        }
        catch (Throwable e) {
            log.error(String.format("sendMsg error : %s", e.getMessage()), e);
        }
    }

    static class AgentMsgWatchWebSocketClient
    extends WebSocketClient {
        private static final Log log = LogFactory.getLog(AgentMsgWatchWebSocketClient.class);
        private static final String ERROR_FLAG = "ERROR:";
        BlockingQueue<Object> messageQueue;
        String chatSessionId;

        public AgentMsgWatchWebSocketClient(URI serverUri, BlockingQueue<Object> messageQueue, String chatSessionId) {
            super(serverUri);
            this.messageQueue = messageQueue;
            this.chatSessionId = chatSessionId;
        }

        public void onMessage(String text) {
            boolean b = this.messageQueue.offer(text);
            if (!b) {
                log.warn("chatSessionId:{} offer message: {} error", (Object)this.chatSessionId, (Object)text);
            }
        }

        public void onMessage(ByteBuffer bytes) {
            boolean b = this.messageQueue.offer(bytes);
            if (!b) {
                log.warn("chatSessionId:{} offer message: bytes error", (Object)this.chatSessionId);
            }
        }

        public void onError(Exception e) {
            boolean b = this.messageQueue.offer(ERROR_FLAG + e.getMessage());
            if (!b) {
                log.warn("chatSessionId:{} offer error to queue ", (Object)this.chatSessionId);
            }
        }

        public void onOpen(ServerHandshake serverHandshake) {
            log.warn("chatSessionId:{} agentMsgWatch onOpen ", (Object)this.chatSessionId);
        }

        public void onClose(int code, String status, boolean b) {
            log.warn("chatSessionId:{} agentMsgWatch onClose ", (Object)this.chatSessionId);
        }
    }
}

