/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.gptas.adapter.rerank.kdai.service;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import kd.bos.context.RequestContext;
import kd.bos.dataentity.serialization.SerializationUtils;
import kd.bos.gptas.api.RerankService;
import kd.bos.gptas.api.rerank.RerankResult;
import kd.bos.gptas.api.vector.Chunk;
import kd.bos.gptas.common.GptasErrorCode;
import kd.bos.gptas.common.GptasException;
import kd.bos.gptas.common.embedding.exception.EmbeddingException;
import kd.bos.gptas.openapi.OpenApiClient;
import kd.bos.gptas.servicehelper.AIServiceProxy;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

/**
 * {@link RerankService} implementation that forwards rerank requests to the
 * AI service identified by {@code rerankNumber} through
 * {@link AIServiceProxy#invokeAICCService}.
 *
 * <p>Inputs larger than {@link #MAX_BATCH_SIZE} are split into batches, each
 * batch is reranked separately, and the per-batch results are merged into a
 * single ranking.
 */
public class KDAIRerankService implements RerankService {
    private static final Log log = LogFactory.getLog(KDAIRerankService.class);
    // The constants and the client below are not referenced anywhere in this class.
    private static final String AI_SERVICE = "ai";
    private static final String AI_MODULE = "aicc";
    private static final String AI_CLASS = "AiccService";
    private static final String AI_METHOD = "syncService";
    private static OpenApiClient openApiClient;
    /** Identifier of the rerank service passed to {@link AIServiceProxy#invokeAICCService}. */
    private final String rerankNumber;
    /**
     * Maximum number of chunks sent in one request. Kept public and mutable for
     * compatibility with existing callers; values below 1 are treated as 1.
     */
    public int MAX_BATCH_SIZE = 64;

    /**
     * Creates a rerank service bound to one service identifier.
     *
     * @param rerankNumber identifier of the rerank service to invoke
     */
    public KDAIRerankService(String rerankNumber) {
        this.rerankNumber = rerankNumber;
    }

    /**
     * Reranks {@code chunks} against {@code content}.
     *
     * <p>When the input fits into one batch the service response is returned
     * as parsed. Otherwise every batch is reranked on its own, origin indices
     * are translated back to positions in {@code chunks}, the results are
     * sorted by descending relevance score, limited to {@code topK} (when
     * {@code topK} is positive) and ranks are reassigned starting at 1.
     *
     * @param content query text the chunks are scored against
     * @param chunks  candidate chunks; {@code null} or empty yields an empty list
     * @param topK    maximum number of results requested
     * @return rerank results with their chunk attached
     * @throws GptasException   if the service call fails or its response is invalid
     * @throws RuntimeException if any chunk has blank content
     */
    public List<RerankResult> rerank(String content, List<Chunk> chunks, int topK) {
        if (CollectionUtils.isEmpty(chunks)) {
            return new ArrayList<RerankResult>(1);
        }
        // Guard against a non-positive batch size, which would never advance the loop below.
        int batchSize = Math.max(1, this.MAX_BATCH_SIZE);
        int total = chunks.size();
        if (total <= batchSize) {
            return this.getRerankResults(content, chunks, topK);
        }

        List<RerankResult> merged = new ArrayList<RerankResult>(total);
        for (int start = 0; start < total; start += batchSize) {
            int end = Math.min(start + batchSize, total);
            List<RerankResult> batchResults = this.getRerankResults(content, chunks.subList(start, end), topK);
            for (RerankResult batchResult : batchResults) {
                // The service reports indices relative to the batch; translate to the caller's list.
                batchResult.setOriginIndex(start + batchResult.getOriginIndex());
                merged.add(batchResult);
            }
        }

        // NOTE(review): relies on RerankResult#getRelevanceScore(), the getter matching
        // setRelevanceScore(...) used in parseResponse — confirm it exists on the API type.
        merged.sort((left, right) -> Double.compare(right.getRelevanceScore(), left.getRelevanceScore()));

        int limit = topK > 0 ? Math.min(topK, merged.size()) : merged.size();
        List<RerankResult> top = new ArrayList<RerankResult>(merged.subList(0, limit));
        for (int position = 0; position < top.size(); ++position) {
            top.get(position).setRank(position + 1);
        }
        return top;
    }

    /**
     * Sends one batch to the rerank service and attaches the matching chunk to
     * every returned result.
     *
     * @param content query text
     * @param chunks  chunks of this batch; none may have blank content
     * @param topK    requested result count, capped at the batch size
     * @return parsed results for this batch, indices relative to {@code chunks}
     * @throws GptasException   if the call, validation or parsing fails
     * @throws RuntimeException if any chunk has blank content
     */
    private List<RerankResult> getRerankResults(String content, List<Chunk> chunks, int topK) {
        Optional<Chunk> blankChunk = chunks.stream()
                .filter(chunk -> StringUtils.isBlank(chunk.getContent()))
                .findFirst();
        if (blankChunk.isPresent()) {
            throw new RuntimeException("chunk content mustn't be empty");
        }
        List<String> documents = chunks.stream().map(Chunk::getContent).collect(Collectors.toList());
        int effectiveTopK = Math.min(topK, documents.size());
        try {
            Map<String, Object> params = this.parseRequest(content, effectiveTopK, documents);
            Map<String, String> context = new HashMap<String, String>();
            context.put("stream", "false");
            String userName = this.getCurrentUserName();
            log.info("User {} calling rerank service {} with params: {}", new Object[]{userName, this.rerankNumber, JSON.toJSON(params)});
            long startTime = System.currentTimeMillis();
            Map<String, String> result = AIServiceProxy.invokeAICCService(context, this.rerankNumber, SerializationUtils.toJsonString(params));
            long duration = System.currentTimeMillis() - startTime;
            log.info("Rerank completed for user {} using {} in {}ms", new Object[]{userName, this.rerankNumber, duration});
            this.validateResponse(result);
            List<RerankResult> rerankResults = this.parseResponse(result.get("result"));
            for (RerankResult rerankResult : rerankResults) {
                rerankResult.setChunk(chunks.get(rerankResult.getOriginIndex()));
            }
            return rerankResults;
        }
        catch (GptasException e) {
            // Already a domain exception with a descriptive message; do not wrap it a second time.
            log.error("Rerank failed", (Throwable)e);
            throw e;
        }
        catch (Exception e) {
            log.error("Rerank failed", (Throwable)e);
            throw new GptasException(GptasErrorCode.UNKNOWN_ERROR, "Rerank failed: " + e.getMessage(), (Throwable)e);
        }
    }

    /**
     * Builds the request payload for the rerank service.
     *
     * @param content   query text, sent as {@code query}
     * @param topK      result count, sent as {@code top_n}
     * @param documents document texts, sent as {@code documents}
     * @return mutable parameter map
     */
    protected Map<String, Object> parseRequest(String content, int topK, List<String> documents) {
        Map<String, Object> params = new HashMap<String, Object>();
        params.put("query", content);
        params.put("documents", documents);
        params.put("top_n", topK);
        return params;
    }

    /**
     * Verifies that the service reported success, i.e. {@code errorCode} equals {@code "0"}.
     *
     * <p>{@code EmbeddingException} is declared for signature compatibility but
     * is not thrown by this method.
     *
     * @param result raw response map from the service
     * @throws GptasException if the map is {@code null} or the error code is not {@code "0"}
     */
    private void validateResponse(Map<String, String> result) throws EmbeddingException {
        if (result == null) {
            throw new GptasException(GptasErrorCode.UNKNOWN_ERROR, "Rerank service returned no response");
        }
        String errorCode = result.get("errorCode");
        if (!"0".equals(errorCode)) {
            String errMessage = result.get("message");
            throw new GptasException(GptasErrorCode.UNKNOWN_ERROR, String.format("Rerank service error (code: %s): %s, result: %s", errorCode, errMessage, result.toString()));
        }
    }

    /**
     * Parses the JSON payload of a rerank response.
     *
     * <p>The payload must be a JSON object containing an {@code id} key. Each
     * entry of its {@code results} array must provide {@code relevance_score}
     * and {@code index}; the rank is the 1-based position within the array.
     *
     * @param response JSON text taken from the {@code result} entry of the service response
     * @return parsed results, empty when {@code results} is missing or empty
     * @throws GptasException if the payload is blank, lacks {@code id}, or an entry is incomplete
     */
    protected List<RerankResult> parseResponse(String response) {
        JSONObject contentJson = StringUtils.isBlank(response) ? null : JSON.parseObject(response);
        if (contentJson == null || !contentJson.containsKey("id")) {
            throw new GptasException(GptasErrorCode.UNKNOWN_ERROR, "Invalid Baidu rerank response: " + response);
        }
        JSONArray jsonArray = contentJson.getJSONArray("results");
        if (jsonArray == null || jsonArray.size() == 0) {
            return new ArrayList<RerankResult>(1);
        }
        List<RerankResult> rerankResults = new ArrayList<RerankResult>(jsonArray.size());
        for (int i = 0; i < jsonArray.size(); ++i) {
            JSONObject entry = jsonArray.getJSONObject(i);
            Float relevanceScore = entry == null ? null : entry.getFloat("relevance_score");
            Integer index = entry == null ? null : entry.getInteger("index");
            if (relevanceScore == null || index == null) {
                // Fail with a clear message instead of a NullPointerException on unboxing.
                throw new GptasException(GptasErrorCode.UNKNOWN_ERROR, "Invalid rerank result entry at position " + i + ": " + entry);
            }
            RerankResult rerankResult = new RerankResult();
            rerankResult.setRelevanceScore(relevanceScore.floatValue());
            rerankResult.setOriginIndex(index.intValue());
            rerankResult.setRank(i + 1);
            rerankResults.add(rerankResult);
        }
        return rerankResults;
    }

    /**
     * Returns the user name held by the current {@link RequestContext}; used for logging only.
     */
    private String getCurrentUserName() {
        return RequestContext.get().getUserName();
    }
}

