/*
 * Decompiled with CFR 0.152.
 */
package kd.ai.gai.core.service.rerank;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import kd.ai.gai.core.domain.dto.RerankChunk;
import kd.ai.gai.core.domain.llm.base.ResultRerank;
import kd.ai.gai.core.engine.Errors;
import kd.ai.gai.core.engine.message.RerankMessage;
import kd.ai.gai.core.service.rerank.RerankService;
import kd.bos.context.RequestContext;
import kd.bos.exception.ErrorCode;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;
import kd.bos.servicehelper.DispatchServiceHelper;
import kd.bos.util.StringUtils;

/**
 * Rerank service for the Baidu reranker model.
 *
 * <p>A request is first validated locally against the size limits held in this class. It is then
 * forwarded as a JSON body ({@code query}, {@code documents}, optional {@code top_n}) to the
 * {@code ai/aicc/AiccService.syncService} dispatch service. Finally, the ranked documents returned
 * by the model are mapped back to the caller's chunk ids.
 */
public class BaiduRerankerService extends RerankService {
    private static final Log logger = LogFactory.getLog(BaiduRerankerService.class);
    /** Error code that {@code checkParam} treats as success (the original code compared against "0"). */
    private static final String OK_CODE = "0";
    /** Maximum query length in Java chars; NOTE(review): assumed to mirror the model's own limit. */
    private final int queryLimit = 1600;
    /** Maximum number of chunks accepted in one rerank request. */
    private final int chunkListLimit = 200;
    /** Maximum length of a single chunk in Java chars. */
    private final int chunklimit = 4096;

    /**
     * Validates the request, calls the rerank model synchronously and converts its answer.
     *
     * @param rerankMessage request carrying the query, the chunks to rank, optional top_n
     *                      (0 means "not sent") and the service number to dispatch to
     * @return a parameter-error result if validation fails; otherwise the error code, message,
     *         ranked chunks and id reported by the dispatch service
     */
    @Override
    public ResultRerank syncRerank(RerankMessage rerankMessage) {
        List<RerankChunk> rerankChunks = rerankMessage.getChunkList();
        String query = rerankMessage.getInputQuery();
        int topN = rerankMessage.getTop_n();
        ErrorCode errorCode = this.checkParam(rerankChunks, query, topN);
        // Compare by value: '!=' on String only worked when both sides happened to be interned.
        if (!OK_CODE.equals(errorCode.getCode())) {
            return new ResultRerank(errorCode.getCode(), errorCode.getMessage());
        }
        List<String> documentList = new ArrayList<>(rerankChunks.size());
        // Several chunks may carry identical text; keep all their ids in input order so each ranked
        // occurrence returned by the model can be attributed to a distinct chunk in getResult.
        Map<String, List<String>> chunkMap = new HashMap<>(rerankChunks.size());
        for (RerankChunk rerank : rerankChunks) {
            String documentText = rerank.getChunk();
            documentList.add(documentText);
            chunkMap.computeIfAbsent(documentText, k -> new ArrayList<>()).add(String.valueOf(rerank.getChunkId()));
        }
        JSONObject modelBody = new JSONObject();
        modelBody.put("query", query);
        modelBody.put("documents", documentList);
        if (topN != 0) {
            modelBody.put("top_n", topN);
        }
        String userName = RequestContext.get().getUserName();
        Map<String, String> contextMap = new HashMap<>();
        contextMap.put("stream", "false");
        logger.info("\u7528\u6237({})\u5f00\u59cb\u8c03\u7528Rerank\u6a21\u578b,\u53c2\u6570\uff1a{}", (Object) userName, (Object) modelBody);
        long start = System.currentTimeMillis();
        Map<?, ?> result = (Map<?, ?>) DispatchServiceHelper.invokeBizService("ai", "aicc", "AiccService", "syncService",
                new Object[]{contextMap, rerankMessage.getServiceNumber(), modelBody.toJSONString()});
        logger.info("\u8c03\u7528Rerank\u6a21\u578b\u540c\u6b65\u670d\u52a1\u6267\u884c\u7ed3\u679c {} \uff0c\u8017\u65f6 {}", JSON.toJSON(result), (Object) (System.currentTimeMillis() - start));
        List<RerankChunk> data = this.getResult(documentList, chunkMap, (String) result.get("result"));
        return new ResultRerank((String) result.get("errorCode"), (String) result.get("message"), data, (String) result.get("id"));
    }

    /**
     * Checks a rerank request against this class's limits.
     *
     * @param rerankChunks chunks to rank; must be non-null, non-empty, at most {@code chunkListLimit}
     *                     elements, and every chunk text non-empty and at most {@code chunklimit} chars
     * @param inputQuery   query text; must be non-empty and at most {@code queryLimit} chars
     * @param topN         requested number of results; must not exceed the number of chunks
     * @return {@code Errors.OK} when valid, otherwise a rerank parameter error (also logged)
     */
    public ErrorCode checkParam(List<RerankChunk> rerankChunks, String inputQuery, int topN) {
        if (rerankChunks == null) {
            return paramError("\u767e\u5ea6Reanker\u6a21\u578b-\u8bf7\u6c42chunk\u5217\u8868\u4e0d\u80fd\u4e3a\u7a7a");
        }
        if (StringUtils.isEmpty(inputQuery)) {
            return paramError("\u767e\u5ea6Reanker\u6a21\u578b-\u8bf7\u6c42inputQuery\u53c2\u6570\u4e0d\u80fd\u4e3a\u7a7a");
        }
        int chunkListSize = rerankChunks.size();
        if (chunkListSize == 0) {
            return paramError("\u767e\u5ea6Reanker\u6a21\u578b-\u8bf7\u6c42chunk\u5217\u8868\u4e0d\u80fd\u4e3a\u7a7a");
        }
        if (inputQuery.length() > this.queryLimit) {
            return paramError("Reanker\u6a21\u578b-query\u53c2\u6570\u957f\u5ea6\u8d85\u51fa\u6a21\u578b\u672c\u8eab\u9650\u5236");
        }
        if (topN > chunkListSize) {
            return paramError("Reanker\u6a21\u578b-topK\u503c\u4e0d\u80fd\u8d85\u8fc7chunk\u5757\u6570\u91cf\uff1a[" + chunkListSize + "]");
        }
        if (chunkListSize > this.chunkListLimit) {
            return paramError("Reanker\u6a21\u578b-chunk\u5757\u6570\u91cf\u8d85\u51fa\u6a21\u578b\u672c\u8eab\u9650\u5236\u5757\u6570\u91cf\uff1a[" + this.chunkListLimit + "]");
        }
        for (RerankChunk rerankChunk : rerankChunks) {
            // A null element is reported as the same "null or empty chunk" error instead of an NPE.
            String chunkStr = rerankChunk == null ? null : rerankChunk.getChunk();
            if (StringUtils.isEmpty(chunkStr)) {
                return paramError("Reanker\u6a21\u578b-\u6bcf\u4e2achunk\u5757\u90fd\u4e0d\u80fd\u4e3anull\u6216\u8005\u7a7a\u5b57\u7b26\u4e32");
            }
            if (chunkStr.length() > this.chunklimit) {
                return paramError("Reanker\u6a21\u578b-\u5b58\u5728chunk\u5757\u957f\u5ea6\u8d85\u51fa\u6a21\u578b\u672c\u8eab\u9650\u5236");
            }
        }
        return Errors.OK;
    }

    /** Logs {@code errorMsg} and wraps it in a rerank parameter error. */
    private ErrorCode paramError(String errorMsg) {
        logger.error(errorMsg);
        return Errors.rerankParamError(errorMsg);
    }

    /**
     * Converts the model's raw JSON answer into ranked chunks, in the order the model returned them.
     *
     * <p>Side effect: ids are consumed (removed) from the lists in {@code chunkMap}, so duplicate
     * texts are matched to distinct chunk ids. Entries whose document text has no remaining id
     * (unknown text, or more copies than were sent) are skipped.
     *
     * @param documentList documents sent to the model (unused here, kept for the overridden contract)
     * @param chunkMap     document text to the still-unassigned chunk ids carrying that text
     * @param modelResult  raw JSON returned by the model, expected to hold a {@code results} array of
     *                     objects with {@code document} and {@code relevance_score}
     * @return the ranked chunks, or {@code null} if {@code modelResult} is empty or has no results array
     */
    @Override
    public List<RerankChunk> getResult(List<String> documentList, Map<String, List<String>> chunkMap, String modelResult) {
        if (StringUtils.isEmpty(modelResult)) {
            return null;
        }
        JSONObject resultJson = JSON.parseObject(modelResult);
        JSONArray resultArray = resultJson == null ? null : resultJson.getJSONArray("results");
        if (resultArray == null) {
            return null;
        }
        List<RerankChunk> resultChunk = new ArrayList<>(resultArray.size());
        for (int i = 0; i < resultArray.size(); i++) {
            JSONObject item = resultArray.getJSONObject(i);
            if (item == null) {
                continue;
            }
            String document = item.getString("document");
            List<String> chunkIds = chunkMap.get(document);
            // The model may return text we never sent; skip it rather than fail with an NPE.
            if (chunkIds == null || chunkIds.isEmpty()) {
                continue;
            }
            RerankChunk rerankChunk = new RerankChunk();
            rerankChunk.setChunkId(chunkIds.remove(0));
            rerankChunk.setChunk(document);
            rerankChunk.setRelevanceScore(item.getFloat("relevance_score"));
            resultChunk.add(rerankChunk);
        }
        return resultChunk;
    }
}

