package kd.ai.gai.core.service.agent;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alipay.api.java_websocket.client.WebSocketClient;
import com.alipay.api.java_websocket.handshake.ServerHandshake;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.Date;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import kd.ai.gai.core.api.websocket.query.GaiRequestContext;
import kd.ai.gai.core.api.websocket.query.RequiredParams;
import kd.ai.gai.core.api.websocket.query.WsRequestParams;
import kd.ai.gai.core.cache.SessionCache;
import kd.ai.gai.core.constant.agent.AgentConstants;
import kd.ai.gai.core.domain.dto.agent.FileInfo;
import kd.ai.gai.core.enuz.agent.RunStepStatusEnum;
import kd.ai.gai.core.service.agent.handler.AgentHandlerFactory;
import kd.ai.gai.core.trace.entity.BaseResult;
import kd.ai.gai.core.util.AgentThreadUtils;
import kd.bos.context.RequestContext;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;
import kd.bos.service.KDDateFormatUtils;
import kd.bos.service.KDDateUtils;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:kd/ai/gai/core/service/agent/AgentMsgWatchWebSocketService.class */
/**
 * Watches the agent message stream of one chat session over a WebSocket connection.
 *
 * <p>{@link AgentMsgWatchWebSocketClient} pushes every received frame into {@link #messageQueue}.
 * When a new connection is opened for a request message, {@code connectWs} starts a worker that
 * drains that queue in {@code onMessageQueue}: text frames are parsed and request messages are
 * dispatched to {@link AgentHandlerFactory}; binary frames are stored via {@code FileService}.
 * The worker stops and closes the connection once a session-run-save response with an end status
 * arrives, or when no frame has been received for {@link #MAX_TIME_OUT} milliseconds.
 *
 * <p>NOTE(review): {@code connectWs}/{@code sendMsg} are not synchronized and {@code closeWs}
 * runs on the worker thread; confirm callers never drive one instance from several threads.
 */
public class AgentMsgWatchWebSocketService {
    private static final Log log = LogFactory.getLog(AgentMsgWatchWebSocketService.class);

    /** Idle timeout in milliseconds (30 minutes) for the queue-consumer loop. */
    private static final int MAX_TIME_OUT = 1800000;

    /** Maximum time one queue poll blocks before the end/idle conditions are re-checked. */
    private static final long POLL_INTERVAL_MS = 100L;

    /** Maximum time to wait for the WebSocket handshake. */
    private static final long CONNECT_TIMEOUT_SECONDS = 10L;

    private AgentMsgWatchWebSocketClient agentMsgWatchWebSocketClient;

    /** Frames received by the client: {@code String} (text or error marker) or {@code ByteBuffer}. */
    BlockingQueue<Object> messageQueue = new LinkedBlockingQueue<>();

    /**
     * WebSocket client that forwards everything it receives to a shared queue without processing
     * it: text frames as {@code String}, binary frames as {@code ByteBuffer}, and transport errors
     * as a {@code String} prefixed with {@code "ERROR:"}.
     */
    public static class AgentMsgWatchWebSocketClient extends WebSocketClient {
        private static final Log log = LogFactory.getLog(AgentMsgWatchWebSocketClient.class);
        private static final String ERROR_FLAG = "ERROR:";
        BlockingQueue<Object> messageQueue;
        String chatSessionId;

        /**
         * @param uri           server endpoint to connect to
         * @param blockingQueue queue that receives every incoming frame
         * @param str           chat session id; only used in log messages
         */
        public AgentMsgWatchWebSocketClient(URI uri, BlockingQueue<Object> blockingQueue, String str) {
            super(uri);
            this.messageQueue = blockingQueue;
            this.chatSessionId = str;
        }

        public void onMessage(String str) {
            if (!this.messageQueue.offer(str)) {
                log.warn("chatSessionId:{} offer message: {} error", this.chatSessionId, str);
            }
        }

        public void onMessage(ByteBuffer byteBuffer) {
            if (!this.messageQueue.offer(byteBuffer)) {
                log.warn("chatSessionId:{} offer message: bytes error", this.chatSessionId);
            }
        }

        public void onError(Exception exc) {
            // Log here as well: the error may occur before any queue consumer is running.
            log.error(String.format("chatSessionId:%s agentMsgWatch onError : %s", this.chatSessionId, exc.getMessage()), exc);
            if (!this.messageQueue.offer(ERROR_FLAG + exc.getMessage())) {
                log.warn("chatSessionId:{} offer error to queue ", this.chatSessionId);
            }
        }

        public void onOpen(ServerHandshake serverHandshake) {
            log.warn("chatSessionId:{} agentMsgWatch onOpen ", this.chatSessionId);
        }

        public void onClose(int code, String reason, boolean remote) {
            log.warn("chatSessionId:{} agentMsgWatch onClose ", this.chatSessionId);
        }
    }

    /**
     * Drains {@link #messageQueue} until the session run ends, the thread is interrupted, or no
     * frame has arrived for {@link #MAX_TIME_OUT} ms; then closes the WebSocket.
     *
     * <p>An exception while handling a frame is logged and reported back (via
     * {@code sendExceptionMsg}) for the last successfully parsed message, if any; draining then
     * continues.
     *
     * @param wsRequestParams the request that opened the connection; its request id is used to
     *                        log how long it took for the matching response to arrive
     */
    private void onMessageQueue(WsRequestParams wsRequestParams) {
        String chatSessionId = getChatSessionId();
        RequiredParams requiredParams = wsRequestParams.getRequiredParams();
        String requestId = requiredParams.getRequestId();
        // Two clocks on purpose: startTime measures response latency, lastActivityTime drives the idle timeout.
        long startTime = KDDateUtils.now().getTime();
        long lastActivityTime = startTime;
        while (KDDateUtils.now().getTime() - lastActivityTime < MAX_TIME_OUT) {
            WsRequestParams processed = null;
            boolean sessionEnded = false;
            try {
                Object message = this.messageQueue.poll(POLL_INTERVAL_MS, TimeUnit.MILLISECONDS);
                // Drain everything currently queued before re-checking the end/timeout conditions.
                while (message != null) {
                    processed = null;
                    lastActivityTime = KDDateUtils.now().getTime();
                    String respondedRequestId = "";
                    if (message instanceof String) {
                        String text = (String) message;
                        if (text.startsWith(AgentMsgWatchWebSocketClient.ERROR_FLAG)) {
                            // Error marker queued by the client's onError; it is not JSON, so do not parse it.
                            log.warn("chatSessionId:{} agentMsgWatch received error: {}", chatSessionId, text);
                        } else {
                            processed = processText(chatSessionId, text);
                            if (processed != null) {
                                if (webSocketMsgIsEnd(processed)) {
                                    sessionEnded = true;
                                }
                                if (processed.isResponse()) {
                                    respondedRequestId = processed.getRequiredParams().getRequestId();
                                }
                            }
                        }
                    } else {
                        respondedRequestId = processBinary((ByteBuffer) message);
                    }
                    if (StringUtils.equalsIgnoreCase(requestId, respondedRequestId)) {
                        logResponseReceived(requiredParams, startTime);
                    }
                    message = this.messageQueue.poll(POLL_INTERVAL_MS, TimeUnit.MILLISECONDS);
                }
            } catch (InterruptedException e) {
                // Restore the flag and stop watching instead of spinning on an interrupted thread.
                Thread.currentThread().interrupt();
                log.warn("chatSessionId:{} agentMsgWatch interrupted, stop watching", chatSessionId);
                break;
            } catch (Exception e) {
                log.error(String.format("【type:%s】, requestId:%s onMessage error : %s", requiredParams.getType(), requiredParams.getRequestId(), e.getMessage()), e);
                sendExceptionMsg(processed, chatSessionId, e.getMessage());
            }
            if (sessionEnded) {
                log.info("toBreak {} with processWsRequestParams: {}", chatSessionId, processed != null ? JSONObject.toJSONString(processed) : null);
                break;
            }
        }
        closeWs(chatSessionId);
    }

    /**
     * Logs the time between {@code startTime} and now for the response to the watched request.
     */
    private void logResponseReceived(RequiredParams requiredParams, long startTime) {
        Date now = KDDateUtils.now();
        log.info("【type:{}】, requestId:{} receive response with [{}, {}](cost:{}s)", new Object[]{
                requiredParams.getType(),
                requiredParams.getRequestId(),
                KDDateFormatUtils.getDateTimeFormat().format(new Date(startTime)),
                KDDateFormatUtils.getDateTimeFormat().format(now),
                Long.valueOf((now.getTime() - startTime) / 1000)});
    }

    /**
     * Turns {@code wsRequestParams} into a failed response and sends it through the service
     * registered for the session. Does nothing when {@code wsRequestParams} is null or no service
     * is registered under {@code str}.
     *
     * @param wsRequestParams message whose type is rewritten from request to response prefix (mutated in place)
     * @param str             chat session id
     * @param str2            error description included in the failure message
     */
    private void sendExceptionMsg(WsRequestParams wsRequestParams, String str, String str2) {
        if (wsRequestParams == null) {
            return;
        }
        RequiredParams requiredParams = wsRequestParams.getRequiredParams();
        String responseType = StringUtils.replace(requiredParams.getType(), WsRequestParams.REQUEST_PREFIX, WsRequestParams.RESPONSE_PREFIX);
        String message = String.format("执行接口异常：%s - %s", responseType, str2);
        requiredParams.setType(responseType);
        requiredParams.setGaiRequestContext(new GaiRequestContext(RequestContext.get()));
        wsRequestParams.setRequiredParams(requiredParams);
        wsRequestParams.setBizParams(BaseResult.fail(message));
        AgentMsgWatchWebSocketService service = AgentConstants.agentMsgWatchWebSocketServiceMap.get(str);
        if (service != null) {
            service.sendMsg(str, wsRequestParams);
        }
    }

    /**
     * Parses a text frame and, if it is a request, runs the matching agent handler on it.
     *
     * @param str  chat session id
     * @param str2 raw JSON text of the frame
     * @return the handler result for requests, the parsed message otherwise, or null when the
     *         text parses to nothing
     */
    private WsRequestParams processText(String str, String str2) {
        log.info("Received WsMessage.text: {}", str2);
        WsRequestParams wsRequestParams = (WsRequestParams) JSON.parseObject(str2, WsRequestParams.class);
        if (wsRequestParams == null) {
            log.warn("chatSessionId:{} ignore empty WsMessage: {}", str, str2);
            return null;
        }
        log.info("Received WsMessage: {} - {}", wsRequestParams.getRequiredParams().getType(), str2);
        if (wsRequestParams.isRequest()) {
            wsRequestParams = AgentHandlerFactory.runHandler(str, wsRequestParams);
        }
        return wsRequestParams;
    }

    /**
     * Returns true when the message is a session-run-save response whose run status is an end
     * status according to {@link RunStepStatusEnum#isEndStatus}.
     */
    private boolean webSocketMsgIsEnd(WsRequestParams wsRequestParams) {
        return wsRequestParams != null
                && wsRequestParams.isResponse()
                && StringUtils.equalsIgnoreCase(wsRequestParams.getRequiredParams().getType(), AgentConstants.RESPONSE_SESSION_RUN_SAVE)
                && RunStepStatusEnum.isEndStatus(((BaseResult) wsRequestParams.getBizParams()).getDataAsJSONObject().getString(AgentConstants.RUN_STATUS));
    }

    /**
     * Stores a binary frame as a file using the metadata cached for this session, removes that
     * cache entry, and sends a file-metadata-save response.
     *
     * @return the request id from the cached metadata, or {@code ""} when no metadata is cached
     */
    private String processBinary(ByteBuffer byteBuffer) {
        String chatSessionId = getChatSessionId();
        String tempFileCacheKey = String.format("%s_%s", AgentConstants.SESSION_TEMP_FILE_PREFIX, chatSessionId);
        String cachedFileInfo = SessionCache.get().get(tempFileCacheKey);
        FileInfo fileInfo = StringUtils.isBlank(cachedFileInfo) ? null : JSONObject.parseObject(cachedFileInfo, FileInfo.class);
        if (fileInfo == null) {
            // Without metadata the bytes cannot be attributed to a request; drop them explicitly.
            log.warn("processBinary: no file metadata cached for tempFileCacheKey:{}, binary frame dropped", tempFileCacheKey);
            return "";
        }
        String requestId = fileInfo.getRequestId();
        log.info("processBinary: tempFileCacheKey:{}, requestId:{}, fileInfo:{}", new Object[]{tempFileCacheKey, requestId, cachedFileInfo});
        fileInfo.setInputStream(byteBufferToInputStream(byteBuffer));
        FileInfo saved = FileService.save(fileInfo);
        saved.setInputStream(null);
        SessionCache.get().remove(tempFileCacheKey);
        AgentMsgWatchWebSocketService service = AgentConstants.agentMsgWatchWebSocketServiceMap.get(chatSessionId);
        if (service != null) {
            service.sendMsg(chatSessionId, new WsRequestParams(AgentConstants.RESPONSE_FILE_METADATA_SAVE, requestId, new BaseResult(saved)));
        } else {
            log.warn("processBinary: no service registered for chatSessionId:{}, file metadata response not sent", chatSessionId);
        }
        return requestId;
    }

    /** Copies the remaining bytes of {@code byteBuffer} (advancing its position) into an in-memory stream. */
    private InputStream byteBufferToInputStream(ByteBuffer byteBuffer) {
        byte[] bytes = new byte[byteBuffer.remaining()];
        byteBuffer.get(bytes);
        return new ByteArrayInputStream(bytes);
    }

    /** Returns the last path segment of the current client URI, i.e. the chat session id. */
    @NotNull
    private String getChatSessionId() {
        String path = this.agentMsgWatchWebSocketClient.getURI().getPath();
        return path.substring(path.lastIndexOf("/") + 1);
    }

    /**
     * Opens a new connection when there is none or the current one is closed. When the new
     * connection succeeds and {@code wsRequestParams} is a request, starts the queue consumer.
     * Errors are logged, not thrown.
     */
    private void connectWs(String str, WsRequestParams wsRequestParams) {
        try {
            if (this.agentMsgWatchWebSocketClient != null && !this.agentMsgWatchWebSocketClient.isClosed()) {
                log.info("chatSessionId {} connect existed and avail.", str);
                return;
            }
            String url = String.format("%s/agent/msgwatch/%s", AgentServiceService.getAgentService("gai_agent").getServerUrl(), str);
            this.agentMsgWatchWebSocketClient = new AgentMsgWatchWebSocketClient(new URI(url), this.messageQueue, str);
            if (!this.agentMsgWatchWebSocketClient.connectBlocking(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                log.warn("chatSessionId {} create connect failed: {}", str, url);
                return;
            }
            log.info("chatSessionId {} create connect success", str);
            if (wsRequestParams.isRequest()) {
                log.info("requiredParams.type {} : {} start onMessageQueue", str, wsRequestParams.getRequiredParams().getType());
                AgentThreadUtils.execute(() -> onMessageQueue(wsRequestParams));
            }
        } catch (Throwable th) {
            if (th instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.error(String.format("error : %s", th.getMessage()), th);
        }
    }

    /**
     * Closes an open connection and always runs the session cleanup afterwards: clears the client,
     * removes the session from {@code agentMsgWatchWebSocketServiceMap}, and calls
     * {@code RunService.expiredRunOrRunStep}.
     *
     * <p>NOTE(review): when the client is null or no longer open this returns without any cleanup;
     * confirm that a server-side close is cleaned up elsewhere.
     */
    private void closeWs(String str) {
        if (this.agentMsgWatchWebSocketClient == null || !this.agentMsgWatchWebSocketClient.isOpen()) {
            return;
        }
        try {
            log.info("closeWs chatSessionId:{}", str);
            this.agentMsgWatchWebSocketClient.closeBlocking();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.error(String.format("error : %s", e.getMessage()), e);
        } finally {
            log.info("remove agentWebSocketRunner chatSessionId:{}", str);
            this.agentMsgWatchWebSocketClient = null;
            AgentConstants.agentMsgWatchWebSocketServiceMap.remove(str);
            RunService.expiredRunOrRunStep(str);
        }
    }

    /**
     * Serializes {@code wsRequestParams}, connects if needed, and sends it. Failures are logged,
     * never thrown.
     *
     * @param str             chat session id
     * @param wsRequestParams message to send
     */
    public void sendMsg(String str, WsRequestParams wsRequestParams) {
        try {
            String jSONString = wsRequestParams.toJSONString();
            log.info("chatSessionId:{}, sendMsg: {}", str, jSONString);
            connectWs(str, wsRequestParams);
            AgentMsgWatchWebSocketClient client = this.agentMsgWatchWebSocketClient;
            if (client == null || !client.isOpen()) {
                log.warn("chatSessionId:{}, sendMsg skipped: websocket is not open", str);
                return;
            }
            client.send(jSONString);
        } catch (Throwable th) {
            log.error(String.format("sendMsg error : %s", th.getMessage()), th);
        }
    }
}
