/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.gptas.qa.service;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import kd.bos.dataentity.serialization.SerializationUtils;
import kd.bos.gptas.api.LLMService;
import kd.bos.gptas.api.llm.MarkedMessageStream;
import kd.bos.gptas.api.llm.Marker;
import kd.bos.gptas.api.llm.MarkerContent;
import kd.bos.gptas.api.llm.MarkerPair;
import kd.bos.gptas.api.llm.MessageStream;
import kd.bos.gptas.qa.model.QAAnswer;
import kd.bos.gptas.qa.model.QAChatHistory;
import kd.bos.gptas.qa.model.QAChatMessage;
import kd.bos.gptas.qa.model.QAChunk;
import kd.bos.gptas.qa.model.QAPrompt;
import kd.bos.gptas.qa.service.ChunkProvider;
import kd.bos.gptas.qa.service.DefaultChunkProvider;
import kd.bos.gptas.qa.service.QAServiceLog;

public class QAService {
    private QAServiceLog logger = new QAServiceLog();
    private ChunkProvider chunkProvider = new DefaultChunkProvider();
    private String firstPromptTemplate = "FirstPrompt.md";
    private String historyPromptTemplate = "HistoryPrompt.md";
    private String historyChunksPromptTemplate = "HistoryChunksPrompt.md";
    private int topK = 5;
    private boolean enableReSeq = false;
    private boolean enableRerank = false;
    private String rerankNumber = "";
    private boolean enableCommonSearch = true;

    public void initPromptTemplate(String firstPromptTemplate, String historyPromptTemplate, String historyChunksPromptTemplate) {
        this.firstPromptTemplate = firstPromptTemplate;
        this.historyPromptTemplate = historyPromptTemplate;
        this.historyChunksPromptTemplate = historyChunksPromptTemplate;
    }

    public String getRerankNumber() {
        return this.rerankNumber;
    }

    public void setRerankNumber(String rerankNumber) {
        this.rerankNumber = rerankNumber;
    }

    public boolean isEnableRerank() {
        return this.enableRerank;
    }

    public void setEnableRerank(boolean enableRerank) {
        this.enableRerank = enableRerank;
    }

    public boolean isEnableReSeq() {
        return this.enableReSeq;
    }

    public void setEnableReSeq(boolean enableReSeq) {
        this.enableReSeq = enableReSeq;
    }

    public boolean isEnableCommonSearch() {
        return this.enableCommonSearch;
    }

    public void setEnableCommonSearch(boolean enableCommonSearch) {
        this.enableCommonSearch = enableCommonSearch;
    }

    public void setChunkProvider(ChunkProvider chunkProvider) {
        if (chunkProvider == null) {
            chunkProvider = new DefaultChunkProvider();
        }
        this.chunkProvider = chunkProvider;
    }

    public void setDebugLog(QAServiceLog logger) {
        this.logger = logger;
    }

    public Iterator<QAAnswer> qa(String chatSessionId, String question, List<String> formIds) {
        LLMService llmService = LLMService.create();
        QAChatHistory qaChatHistory = new QAChatHistory(chatSessionId);
        List<QAChatMessage> chatMessageList = qaChatHistory.getChatMessageList();
        if (chatMessageList.isEmpty()) {
            this.logger.info("\u7b2c\u4e00\u8f6e\u4f1a\u8bdd");
            QAPrompt qaPrompt = this.buildFirstPrompt(question, formIds, qaChatHistory);
            return QAService.runLLM(llmService, question, qaPrompt, qaChatHistory);
        }
        QAPrompt qaPrompt = this.buildPrompt(this.historyPromptTemplate, question, qaChatHistory);
        ArrayList<Marker> markerList = new ArrayList<Marker>(16);
        markerList.add(new Marker("answer", true));
        markerList.add(new Marker("retrieve", true));
        markerList.add(new Marker("irrelhis", true));
        MessageStream messageStream = llmService.llm(qaPrompt.getPrompt(), null);
        MarkedMessageStream markerAwareMessageStream = llmService.markerAware(messageStream, markerList);
        String tag = "";
        while (markerAwareMessageStream.hasNext()) {
            MarkerPair next = (MarkerPair)markerAwareMessageStream.next();
            if (!next.isMarked()) continue;
            switch (next.getMarkerContent().getMarker().getStartTag()) {
                case "answer": {
                    this.logger.info("answer \u76f4\u63a5\u56de\u7b54\u7ed3\u679c\u3002");
                    QAChatMessage qaChatMessage = new QAChatMessage();
                    qaChatMessage.setQuestion(question);
                    qaChatMessage.setPrompt(qaPrompt.getPrompt());
                    qaChatHistory.getChatMessageList().add(qaChatMessage);
                    return new QAAnswerIterator(messageStream, qaChatHistory, (MarkerPair<String, MarkerContent>)next);
                }
                case "retrieve": 
                case "irrelhis": {
                    tag = next.getMarkerContent().getMarker().getStartTag();
                }
            }
            if (tag.isEmpty()) continue;
            break;
        }
        List allMarkerContents = markerAwareMessageStream.getAllMarkerContents();
        if (tag.equals("retrieve")) {
            this.logger.info(tag + " \u8865\u5145\u5411\u91cf\u67e5\u8be2\u7ed3\u679c\u3002");
            String retriveQuestion = ((MarkerContent)allMarkerContents.get(0)).getContent();
            List<QAChunk> chunkList = this.getChunkList(formIds, retriveQuestion, true);
            qaChatHistory.addChunkList(chunkList);
            qaPrompt = this.buildPrompt(this.historyChunksPromptTemplate, question, qaChatHistory);
            return QAService.runLLM(llmService, question, qaPrompt, qaChatHistory);
        }
        if (tag.equals("irrelhis") || tag.isEmpty()) {
            this.logger.info(tag + " \u65e0\u5173\u5386\u53f2\uff0c\u6e05\u9664\u5386\u53f2\u7eaa\u5f55\u3002");
            qaChatHistory.clearChatMessage();
            qaPrompt = this.buildFirstPrompt(question, formIds, qaChatHistory);
            return QAService.runLLM(llmService, question, qaPrompt, qaChatHistory);
        }
        return new QAAnswerNoData();
    }

    private QAPrompt buildFirstPrompt(String question, List<String> formIds, QAChatHistory qaChatHistory) {
        List<QAChunk> chunkList = this.getChunkList(formIds, question, false);
        qaChatHistory.addChunkList(chunkList);
        QAPrompt qaPrompt = this.buildPrompt(this.firstPromptTemplate, question, qaChatHistory);
        this.logger.info("\u751f\u6210\u63d0\u793a\u8bcd\uff1a" + qaPrompt.getPrompt());
        return qaPrompt;
    }

    private static QAAnswerIterator runLLM(LLMService llmService, String question, QAPrompt qaPrompt, QAChatHistory qaChatHistory) {
        QAChatMessage qaChatMessage = new QAChatMessage();
        qaChatMessage.setQuestion(question);
        qaChatMessage.setPrompt(qaPrompt.getPrompt());
        qaChatHistory.getChatMessageList().add(qaChatMessage);
        MessageStream messageStream = llmService.llm(qaPrompt.getPrompt(), null);
        return new QAAnswerIterator(messageStream, qaChatHistory);
    }

    private List<QAChunk> getChunkList(List<String> formIds, String question, boolean isRetrieve) {
        if (isRetrieve) {
            List<QAChunk> chunkList = ChunkProvider.createDefault().getChunkList(formIds, question, this.topK, this.enableReSeq, this.enableRerank, this.rerankNumber, this.enableCommonSearch);
            this.chunkProvider.onGetChunk(chunkList);
            return chunkList;
        }
        return this.chunkProvider.getChunkList(formIds, question, this.topK, this.enableReSeq, this.enableRerank, this.rerankNumber, this.enableCommonSearch);
    }

    private QAPrompt buildPrompt(String promptTemplate, String question, QAChatHistory qaChatHistory) {
        List<QAChatMessage> chatMessageList = qaChatHistory.getChatMessageList();
        StringBuilder history = new StringBuilder();
        int i = 0;
        for (QAChatMessage qaChatMessage : chatMessageList) {
            history.append("\n\u95ee\u9898").append(i + 1).append(":").append(qaChatMessage.getQuestion()).append("\u56de\u7b54").append(i + 1).append(qaChatMessage.getAnswer()).append("\n");
            ++i;
        }
        StringBuilder chunks = new StringBuilder();
        ArrayList list = new ArrayList(16);
        for (QAChunk qaChunk : qaChatHistory.getQaChunkList()) {
            HashMap<String, String> map = new HashMap<String, String>(16);
            map.put("doc_id", qaChunk.getFormId() + "_" + String.valueOf(qaChunk.getId()));
            map.put("content", qaChunk.getContent());
            list.add(map);
        }
        chunks.append(SerializationUtils.toJsonString(list)).append("\n");
        String prompt = QAPrompt.getPromptTemplateFromFile("prompt/" + promptTemplate).replace("{{question}}", question).replace("{{chunks}}", chunks).replace("{{history}}", history);
        return new QAPrompt(prompt);
    }

    public int getTopK() {
        return this.topK;
    }

    public void setTopK(int topK) {
        this.topK = topK;
    }

    static class QAAnswerIterator
    implements Iterator<QAAnswer> {
        private static final String REFTAG = "########";
        private final MarkedMessageStream stream;
        private MarkerPair<String, MarkerContent> firstContent = null;
        private final QAChatHistory qaChatHistory;
        private boolean hasNext = true;
        private boolean isFinished = false;

        public QAAnswerIterator(MessageStream stream, QAChatHistory qaChatHistory) {
            this.stream = this.wrapperStream(stream);
            this.qaChatHistory = qaChatHistory;
        }

        public QAAnswerIterator(MessageStream stream, QAChatHistory qaChatHistory, MarkerPair<String, MarkerContent> firstContent) {
            this.stream = this.wrapperStream(stream);
            this.firstContent = firstContent;
            this.qaChatHistory = qaChatHistory;
        }

        private MarkedMessageStream wrapperStream(MessageStream messageStream) {
            ArrayList<Marker> markerList = new ArrayList<Marker>(1);
            Marker marker = new Marker(REFTAG, false);
            markerList.add(marker);
            return LLMService.create().markerAware(messageStream, markerList);
        }

        @Override
        public boolean hasNext() {
            if (this.firstContent != null) {
                return true;
            }
            this.hasNext = this.stream.hasNext();
            if (!this.hasNext) {
                return !this.isFinished;
            }
            return true;
        }

        @Override
        public QAAnswer next() {
            if (this.firstContent != null) {
                QAAnswer qaAnswer = new QAAnswer();
                qaAnswer.setAnswer((String)this.firstContent.getContent());
                this.qaChatHistory.appendAnswer((String)this.firstContent.getContent());
                this.firstContent = null;
                return qaAnswer;
            }
            if (this.hasNext) {
                MarkerPair next = (MarkerPair)this.stream.next();
                QAAnswer qaAnswer = new QAAnswer();
                qaAnswer.setAnswer((String)next.getContent());
                this.qaChatHistory.appendAnswer((String)next.getContent());
                return qaAnswer;
            }
            QAAnswer qaAnswer = this.createEndQAAnswer();
            this.qaChatHistory.save();
            this.isFinished = true;
            return qaAnswer;
        }

        private QAAnswer createEndQAAnswer() {
            QAAnswer qaAnswer = new QAAnswer();
            qaAnswer.setAnswer("");
            List allMarkerContents = this.stream.getAllMarkerContents();
            ArrayList<String> refChunkList = new ArrayList<String>(16);
            allMarkerContents.forEach(o -> {
                if (o.getMarker().getStartTag().equals(REFTAG)) {
                    refChunkList.add(o.getContent());
                }
            });
            qaAnswer.setRefChunkList(refChunkList);
            qaAnswer.setEnd(true);
            return qaAnswer;
        }
    }

    static class QAAnswerNoData
    implements Iterator<QAAnswer> {
        boolean hasNext = true;

        QAAnswerNoData() {
        }

        @Override
        public boolean hasNext() {
            return this.hasNext;
        }

        @Override
        public QAAnswer next() {
            if (this.hasNext) {
                this.hasNext = false;
                QAAnswer qaAnswer = new QAAnswer();
                qaAnswer.setAnswer("no data");
                qaAnswer.setEnd(true);
                return qaAnswer;
            }
            return null;
        }
    }
}

