/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.gptas.autoact.llm.baidu;

import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import java.io.IOException;
import java.io.Serializable;
import java.net.URLEncoder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import kd.bos.gptas.autoact.exception.ChatException;
import kd.bos.gptas.autoact.llm.ChatModel;
import kd.bos.gptas.autoact.llm.baidu.BaiduAccessTokenManager;
import kd.bos.gptas.autoact.message.AIMessage;
import kd.bos.gptas.autoact.message.ChatMessage;
import kd.bos.gptas.autoact.message.ChatMessageType;
import kd.bos.gptas.autoact.message.IterMessageStream;
import kd.bos.gptas.autoact.message.read.StreamLineReader;
import kd.bos.gptas.autoact.model.Tool;
import kd.bos.gptas.autoact.monitor.Cost;
import kd.bos.gptas.autoact.output.TokenUsage;
import kd.bos.gptas.autoact.util.JsonUtil;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;

public class BaiduERNIEBot40Model
implements ChatModel {
    public static final String CONTENT_TYPE = "Content-Type";
    public static final String TEXT_EVENT_STREAM = "text/event-stream";
    private String apiKey;
    private String securityKey;
    private long readTimeout = 20000L;

    @Override
    public String name() {
        return "ERNIEBot40";
    }

    @Override
    public kd.bos.gptas.autoact.output.Response<AIMessage> generate(List<ChatMessage> messages) {
        return this.doGenerate(messages, true);
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private kd.bos.gptas.autoact.output.Response<AIMessage> doGenerate(List<ChatMessage> messages, boolean stream) {
        try (Cost cost = Cost.trace("BaiduERNIEBot40Model.generate");){
            OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().readTimeout(this.readTimeout, TimeUnit.MILLISECONDS).build();
            MediaType mediaType = MediaType.parse((String)"application/json");
            String requestString = this.toRequestJsonString(messages, stream);
            RequestBody body = RequestBody.create((MediaType)mediaType, (String)requestString);
            Request request = new Request.Builder().url("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + this.getAccessToken(HTTP_CLIENT)).method("POST", body).addHeader(CONTENT_TYPE, "application/json").build();
            Function<String, kd.bos.gptas.autoact.output.Response> resultConsumer = retBody -> {
                JSONObject jsonObject = JsonUtil.parseObject(retBody);
                if (jsonObject.get((Object)"error_code") != null && jsonObject.get((Object)"error_msg") != null) {
                    if (110 == jsonObject.getInteger("error_code")) {
                        BaiduAccessTokenManager.clearAccessToken(this.apiKey, this.securityKey);
                        return this.generate(messages);
                    }
                    throw new ChatException(jsonObject.get((Object)"error_msg") + "(error_code=" + jsonObject.get((Object)"error_code") + ")");
                }
                String result = jsonObject.getString("result");
                kd.bos.gptas.autoact.output.Response<AIMessage> response = new kd.bos.gptas.autoact.output.Response<AIMessage>(new AIMessage(result));
                response.setFinishReason(jsonObject.getString("finish_reason"));
                JSONObject jsonUsage = jsonObject.getJSONObject("usage");
                if (jsonUsage != null) {
                    TokenUsage tokenUsage = response.getUsage();
                    tokenUsage.setPromptTokens(jsonUsage.getInteger("prompt_tokens"));
                    tokenUsage.setCompletionTokens(jsonUsage.getInteger("completion_tokens"));
                    tokenUsage.setTotalTokens(jsonUsage.getInteger("total_tokens"));
                }
                return response;
            };
            Response response = HTTP_CLIENT.newCall(request).execute();
            String contentType = response.header(CONTENT_TYPE);
            if (contentType != null && contentType.contains(TEXT_EVENT_STREAM)) {
                StreamLineReader lineReader = new StreamLineReader(response.body().byteStream());
                kd.bos.gptas.autoact.output.Response<AIMessage> response2 = new kd.bos.gptas.autoact.output.Response<AIMessage>(new AIMessage(new IterMessageStream(lineReader, (AutoCloseable)response, data -> ((AIMessage)((kd.bos.gptas.autoact.output.Response)resultConsumer.apply(data.replaceFirst("data: ", ""))).getResult()).getMessageText())));
                return response2;
            }
            String retBody2 = response.body().string();
            response.close();
            kd.bos.gptas.autoact.output.Response response3 = resultConsumer.apply(retBody2);
            return response3;
        }
        catch (Exception e) {
            throw ChatException.asChatException(e);
        }
    }

    private String toRequestJsonString(List<ChatMessage> messages, boolean stream) {
        HashMap<String, Serializable> ret = new HashMap<String, Serializable>();
        ret.put("disable_search", Boolean.FALSE);
        ret.put("enable_citation", Boolean.FALSE);
        ret.put("stream", Boolean.valueOf(stream));
        ArrayList msgItemList = new ArrayList(messages.size());
        boolean lastIsSystem = false;
        String systemMessage = null;
        for (ChatMessage message : messages) {
            ChatMessageType type = message.getType();
            switch (type) {
                case SYSTEM: {
                    systemMessage = message.getMessage();
                    break;
                }
                case USER: {
                    HashMap<String, String> itemMap = new HashMap<String, String>(4);
                    itemMap.put("role", "user");
                    if (lastIsSystem) {
                        itemMap.put("content", systemMessage + "\n" + message.getMessage());
                    } else {
                        itemMap.put("content", message.getMessage());
                    }
                    msgItemList.add(itemMap);
                    break;
                }
                case AI: 
                case TOOL: {
                    HashMap<String, String> itemMap = new HashMap(4);
                    itemMap.put("role", "assistant");
                    itemMap.put("content", message.getMessage());
                    msgItemList.add(itemMap);
                    break;
                }
                default: {
                    throw new ChatException("Unsupported message type: " + (Object)((Object)type));
                }
            }
            lastIsSystem = type == ChatMessageType.SYSTEM;
        }
        ret.put("messages", msgItemList);
        return JsonUtil.toJSONString(ret, new SerializerFeature[0]);
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private String getAccessToken(OkHttpClient HTTP_CLIENT) {
        String accessToken = BaiduAccessTokenManager.getAccessToken(this.apiKey, this.securityKey);
        if (accessToken != null) {
            return accessToken;
        }
        try (Cost cost = Cost.trace("BaiduERNIEBot40Model.getAccessToken");){
            MediaType mediaType = MediaType.parse((String)"application/x-www-form-urlencoded");
            RequestBody body = RequestBody.create((MediaType)mediaType, (String)("grant_type=client_credentials&client_id=" + this.apiKey + "&client_secret=" + this.securityKey));
            Request request = new Request.Builder().url("https://aip.baidubce.com/oauth/2.0/token").method("POST", body).addHeader(CONTENT_TYPE, "application/x-www-form-urlencoded").build();
            Response response = HTTP_CLIENT.newCall(request).execute();
            JSONObject jsonObject = JsonUtil.parseObject(response.body().string());
            if (jsonObject.get((Object)"error") != null) {
                throw new ChatException(jsonObject.get((Object)"error") + ": " + jsonObject.get((Object)"error_description"));
            }
            long expiredTime = System.currentTimeMillis() + (long)((double)((long)jsonObject.getInteger("expires_in").intValue() * 1000L) * 0.9);
            accessToken = jsonObject.getString("access_token");
            BaiduAccessTokenManager.setAccessToken(this.apiKey, this.securityKey, accessToken, expiredTime);
            String string = accessToken;
            return string;
        }
        catch (IOException e) {
            throw new ChatException("BaiduChat getAccessToken error: " + e.getMessage(), e);
        }
    }

    @Override
    public kd.bos.gptas.autoact.output.Response<AIMessage> generate(List<ChatMessage> messages, List<Tool> tools) {
        return ChatModel.super.generate(messages, tools);
    }

    public static BaiduERNIEBot40ChatBuilder builder() {
        return new BaiduERNIEBot40ChatBuilder();
    }

    public static class BaiduERNIEBot40ChatBuilder {
        private String apiKey;
        private String securityKey;
        private long readTimeout;

        public BaiduERNIEBot40Model build() {
            BaiduERNIEBot40Model chat = new BaiduERNIEBot40Model();
            chat.apiKey = this.apiKey;
            chat.securityKey = this.securityKey;
            chat.readTimeout = this.readTimeout;
            return chat;
        }

        public BaiduERNIEBot40ChatBuilder key(String apiKey, String securityKey) {
            this.apiKey = apiKey;
            this.securityKey = securityKey;
            return this;
        }

        public BaiduERNIEBot40ChatBuilder readTimeout(long readTimeout) {
            this.readTimeout = readTimeout;
            return this;
        }
    }
}

