/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.gptas.common;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import kd.bos.context.RequestContext;
import kd.bos.gptas.api.VectorService;
import kd.bos.gptas.api.km.split.SplitConfig;
import kd.bos.gptas.api.vector.Chunk;
import kd.bos.gptas.api.vector.EmbeddingModel;
import kd.bos.gptas.api.vector.VectorResult;
import kd.bos.gptas.api.vector.VectorTask;
import kd.bos.gptas.common.EmbeddingCompleteListener;
import kd.bos.gptas.common.KnowledgeVectorStorePlugin;
import kd.bos.gptas.common.VectorStorePlugin;
import kd.bos.gptas.common.embedding.EmbeddingFactory;
import kd.bos.gptas.common.embedding.exception.EmbeddingException;
import kd.bos.gptas.common.embedding.model.EmbeddingVector;
import kd.bos.gptas.common.embedding.service.EmbeddingService;
import kd.bos.gptas.common.vectordb.VectorStoreFactory;
import kd.bos.gptas.common.vectordb.model.VectorChunk;
import kd.bos.gptas.common.vectordb.model.VectorQuery;
import kd.bos.gptas.common.vectordb.service.VectorStoreService;
import kd.bos.gptas.common.vectortask.VectorTaskManager;
import kd.bos.gptas.utils.SystemPropertyUtils;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;
import kd.bos.threads.ThreadPool;
import kd.bos.threads.ThreadPools;
import kd.bos.util.Pair;
import org.apache.commons.collections4.CollectionUtils;

/**
 * {@link VectorService} implementation that embeds chunk/query text with an
 * {@link EmbeddingService} and reads/writes the resulting vectors through a
 * {@link VectorStoreService}.
 *
 * <p>Registered {@link VectorStorePlugin}s are consulted to decide whether real
 * embedding is enabled for an entity, are run before chunks are stored
 * ({@code preStore}) and after a search returns ({@code postSearch}).
 *
 * <p>NOTE(review): the plugin and listener lists are plain {@link ArrayList}s
 * that are also read from worker threads in {@link #batchSave}; registering
 * plugins/listeners concurrently with a running batch is presumably not
 * expected - confirm with callers.
 */
public class VectorServiceImpl
implements VectorService {
    /** Batch size used when {@code VECTOR_BATCH_SPLIT_SIZE} is not configured or is not positive. */
    private static final int BATCH_SPLIT_SIZE = 16;
    /** Shared pool on which the batches of {@link #batchSave} are embedded in parallel. */
    private static final ThreadPool threadPools = ThreadPools.newFixedThreadPool("kd.bos.gptas.common.VectorServiceImpl", 20);
    private static final Log logger = LogFactory.getLog(VectorServiceImpl.class);
    /** Embedding service created by {@link EmbeddingFactory} for the configured model. */
    private final EmbeddingService embeddingServiceCache;
    /** Placeholder service used when no plugin enables vectorization for the entities involved. */
    private final EmbeddingService noEmbeddingServiceCache;
    protected VectorStoreService vectorStore;
    private List<VectorStorePlugin> vectorStorePluginList = new ArrayList<VectorStorePlugin>(16);
    private final List<EmbeddingCompleteListener> embeddingCompleteListenerList = new ArrayList<EmbeddingCompleteListener>(16);

    /**
     * Returns the live list of registered vector store plugins.
     *
     * @return the plugin list currently used by this service (not a copy)
     */
    public List<VectorStorePlugin> getVectorStorePluginList() {
        return this.vectorStorePluginList;
    }

    /**
     * Replaces the registered plugins on this service and on the underlying
     * vector store. A {@code null} argument is ignored.
     *
     * @param vectorStorePluginList the new plugin list, or {@code null} to keep the current one
     */
    public void setVectorStorePlugin(List<VectorStorePlugin> vectorStorePluginList) {
        if (vectorStorePluginList == null) {
            return;
        }
        this.vectorStorePluginList = vectorStorePluginList;
        this.vectorStore.setVectorStorePlugin(vectorStorePluginList);
    }

    /**
     * Registers a listener that is notified after each batch of
     * {@link #batchSave} has been embedded.
     *
     * @param embeddingCompleteListener the listener to add
     */
    public void addEmbeddingCompleteListener(EmbeddingCompleteListener embeddingCompleteListener) {
        this.embeddingCompleteListenerList.add(embeddingCompleteListener);
    }

    /**
     * Creates the service for the given embedding model, building the embedding
     * service and vector store from their factories and registering a
     * {@link KnowledgeVectorStorePlugin} as the initial plugin.
     *
     * @param embeddingModel the model used for embedding and for the vector store
     */
    public VectorServiceImpl(EmbeddingModel embeddingModel) {
        this.noEmbeddingServiceCache = new NoEmbeddingServiceImpl(embeddingModel);
        this.embeddingServiceCache = EmbeddingFactory.create(embeddingModel);
        this.vectorStore = VectorStoreFactory.create(embeddingModel);
        this.vectorStorePluginList.add(new KnowledgeVectorStorePlugin());
        this.vectorStore.setVectorStorePlugin(this.vectorStorePluginList);
    }

    /**
     * Searches the given entities for content similar to {@code content}
     * without knowledge or chunk id filters.
     *
     * @param entityIds entities (repositories) to search
     * @param content   query text to embed
     * @param topK      maximum number of results requested
     * @return the search results after plugin post-processing
     */
    public List<VectorResult> search(List<String> entityIds, String content, int topK) {
        return this.search(entityIds, null, null, null, content, topK);
    }

    /**
     * Searches with optional knowledge and chunk id filters.
     *
     * @param entityIds    entities (repositories) to search
     * @param knowledgeIds knowledge id filter; ignored when null or empty
     * @param chunkIds     chunk id filter; ignored when null or empty
     * @param content      query text to embed
     * @param topK         maximum number of results requested
     * @return the search results after plugin post-processing
     */
    public List<VectorResult> search(List<String> entityIds, List<Long> knowledgeIds, List<Long> chunkIds, String content, int topK) {
        return this.search(entityIds, null, null, knowledgeIds, chunkIds, content, topK);
    }

    /**
     * Searches with optional chunk type, knowledge and chunk id filters.
     *
     * @param entityIds    entities (repositories) to search
     * @param chunkTypes   chunk type filter; ignored when null or empty
     * @param knowledgeIds knowledge id filter; ignored when null or empty
     * @param chunkIds     chunk id filter; ignored when null or empty
     * @param content      query text to embed
     * @param topK         maximum number of results requested
     * @return the search results after plugin post-processing
     */
    public List<VectorResult> search(List<String> entityIds, List<String> chunkTypes, List<Long> knowledgeIds, List<Long> chunkIds, String content, int topK) {
        return this.search(entityIds, chunkTypes, null, knowledgeIds, chunkIds, content, topK);
    }

    /**
     * Runs an already-built query against the vector store and passes the raw
     * results through every registered plugin's {@code postSearch}.
     *
     * <p>When no plugin is registered the store's results are returned as-is.
     *
     * @param query the query to execute
     * @return the concatenation of each plugin's {@code postSearch} output, or
     *         the raw store results when there are no plugins
     */
    public List<VectorResult> searchByVector(VectorQuery query) {
        List<VectorResult> results = this.vectorStore.search(query);
        if (this.vectorStorePluginList.isEmpty()) {
            return results;
        }
        // NOTE(review): every plugin receives the same raw results and the outputs are
        // concatenated (not chained like preStore); with more than one plugin this can
        // return an item several times - confirm this is intended.
        List<VectorResult> vectorResults = new ArrayList<VectorResult>(16);
        for (VectorStorePlugin vectorStorePlugin : this.vectorStorePluginList) {
            vectorResults.addAll(vectorStorePlugin.postSearch(results));
        }
        return vectorResults;
    }

    /**
     * Embeds {@code content}, builds a {@link VectorQuery} from the given
     * filters and delegates to {@link #searchByVector(VectorQuery)}.
     *
     * @param entityIds    entities (repositories) to search
     * @param chunkTypes   chunk type filter; ignored when null or empty
     * @param groupIds     group id filter; ignored when null or empty
     * @param knowledgeIds knowledge id filter; ignored when null or empty
     * @param chunkIds     chunk id filter; ignored when null or empty
     * @param content      query text; embedded and also set as the query text
     * @param topK         maximum number of results requested
     * @return the search results after plugin post-processing
     */
    public List<VectorResult> search(List<String> entityIds, List<String> chunkTypes, List<Long> groupIds, List<Long> knowledgeIds, List<Long> chunkIds, String content, int topK) {
        EmbeddingVector queryVector = this.getEmbeddingService(entityIds).embed(content);
        VectorQuery query = new VectorQuery();
        query.setRepositoryIds(entityIds);
        query.setQueryVector(queryVector.getVector());
        if (CollectionUtils.isNotEmpty(chunkTypes)) {
            query.setChunkTypes(chunkTypes);
        }
        if (CollectionUtils.isNotEmpty(groupIds)) {
            query.setGroupIds(groupIds);
        }
        if (CollectionUtils.isNotEmpty(knowledgeIds)) {
            query.setKnowledgeIds(knowledgeIds);
        }
        if (CollectionUtils.isNotEmpty(chunkIds)) {
            query.setChunkIds(chunkIds);
        }
        query.setTopK(topK);
        query.setQueryText(content);
        return this.searchByVector(query);
    }

    /**
     * Builds a {@link Chunk} with group id 0 from the given values and saves it.
     *
     * @param entityId    entity the chunk belongs to
     * @param knowledgeId knowledge id of the chunk
     * @param chunkId     id of the chunk
     * @param content     text to embed and store
     */
    public void save(String entityId, Long knowledgeId, Long chunkId, String content) {
        Chunk chunk = new Chunk();
        chunk.setEntityId(entityId);
        chunk.setId(chunkId);
        chunk.setKnowledgeId(knowledgeId);
        chunk.setContent(content);
        chunk.setGroupId(Long.valueOf(0L));
        this.save(chunk);
    }

    /**
     * Embeds a single chunk, runs it through every plugin's {@code preStore}
     * and stores the result. Nothing is stored when the plugins return an
     * empty list.
     *
     * @param chunk the chunk to embed and store
     */
    public void save(Chunk chunk) {
        EmbeddingService embeddingService = this.getEmbeddingService(Collections.singletonList(chunk.getEntityId()));
        EmbeddingVector embedded = embeddingService.embed(chunk.getContent());
        List<VectorChunk> vectorChunkList = new ArrayList<VectorChunk>(16);
        vectorChunkList.add(VectorChunk.asVector(chunk, embedded.getVector()));
        for (VectorStorePlugin vectorStorePlugin : this.vectorStorePluginList) {
            vectorChunkList = vectorStorePlugin.preStore(vectorChunkList);
        }
        if (!vectorChunkList.isEmpty()) {
            // NOTE(review): only the first element is stored; any extra chunks a plugin
            // might return from preStore are dropped - confirm plugins never expand the list.
            this.vectorStore.store(vectorChunkList.get(0));
        }
    }

    /**
     * Chooses the embedding service for the given entities: the real service
     * as soon as any plugin reports vectorization enabled for any of them,
     * otherwise the placeholder service.
     *
     * @param entityIDs entity ids to check against each plugin's split config
     * @return the real embedding service or the placeholder one
     */
    protected EmbeddingService getEmbeddingService(List<String> entityIDs) {
        for (VectorStorePlugin vectorStorePlugin : this.vectorStorePluginList) {
            for (String entityID : entityIDs) {
                SplitConfig splitConfig = vectorStorePlugin.getSplitConfig(entityID);
                // A plugin with no config for this entity is treated as "not enabled"
                // instead of failing with a NullPointerException.
                if (splitConfig != null && splitConfig.isEnableVector()) {
                    return this.embeddingServiceCache;
                }
            }
        }
        return this.noEmbeddingServiceCache;
    }

    /**
     * Submits the chunks to a new {@link VectorTaskManager} for saving.
     *
     * @param chunks         chunks to embed and store
     * @param resultCallBack callback handed to the task manager
     * @return the id returned by {@link VectorTaskManager#submit}
     */
    public String submitSaveTask(List<Chunk> chunks, Consumer<List<Chunk>> resultCallBack) {
        List<String> entityIDs = chunks.stream().map(Chunk::getEntityId).collect(Collectors.toList());
        EmbeddingService embeddingService = this.getEmbeddingService(entityIDs);
        VectorTaskManager vectorTaskManager = new VectorTaskManager(embeddingService, this.vectorStore, this.vectorStorePluginList);
        return vectorTaskManager.submit(chunks, resultCallBack);
    }

    /**
     * Looks up a previously submitted save task.
     *
     * @param taskId id returned by {@link #submitSaveTask}
     * @return the task as returned by {@link VectorTaskManager#getVectorTask}
     */
    public VectorTask getSaveTask(String taskId) {
        return VectorTaskManager.getVectorTask(taskId);
    }

    /**
     * Deletes the given chunks of an entity from the vector store.
     *
     * @param entityId entity the chunks belong to
     * @param chunkIds ids of the chunks to delete
     */
    public void delete(String entityId, List<Long> chunkIds) {
        this.vectorStore.delete(entityId, chunkIds);
    }

    /**
     * Deletes everything stored for an entity from the vector store.
     *
     * @param entityId entity whose vectors are removed
     */
    public void deleteAll(String entityId) {
        this.vectorStore.deleteAll(entityId);
    }

    /**
     * Embeds the chunks in parallel batches, runs them through every plugin's
     * {@code preStore} and stores them in one batch call.
     *
     * @param entityId not used by this implementation; each batch takes the
     *                 entity id from its first chunk
     * @param chunks   chunks to embed and store
     * @return {@code true} when the value side of the pair returned by
     *         {@code doBatchStore} is empty (presumably the failed chunks - confirm)
     * @throws RuntimeException wrapping the root cause when embedding fails or
     *                          the calling thread is interrupted while waiting
     */
    public boolean batchSave(String entityId, List<Chunk> chunks) {
        List<VectorChunk> vectorChunkList = this.batchEmbed(chunks);
        for (VectorStorePlugin vectorStorePlugin : this.vectorStorePluginList) {
            vectorChunkList = vectorStorePlugin.preStore(vectorChunkList);
        }
        Pair<List<VectorChunk>, List<VectorChunk>> result = this.vectorStore.doBatchStore(vectorChunkList);
        return ((List<?>)result.getValue()).isEmpty();
    }

    /**
     * Reads the tenant's {@code VECTOR_BATCH_SPLIT_SIZE} setting.
     *
     * @return the configured size, or {@link #BATCH_SPLIT_SIZE} when the
     *         configured value is not positive (a size of 0 would otherwise
     *         make batch splitting loop forever)
     */
    private int getBatchSplitSize() {
        int configured = SystemPropertyUtils.getInteger(RequestContext.get().getTenantId(), "VECTOR_BATCH_SPLIT_SIZE", BATCH_SPLIT_SIZE);
        return configured > 0 ? configured : BATCH_SPLIT_SIZE;
    }

    /**
     * Splits {@code chunks} into consecutive sub-list views of at most
     * {@code size} elements.
     *
     * @param chunks the list to split
     * @param size   maximum batch size; must be positive
     * @return the batches, in order; empty when {@code chunks} is empty
     * @throws IllegalArgumentException if {@code size} is not positive
     */
    private List<List<Chunk>> spitChunks(List<Chunk> chunks, int size) {
        if (size <= 0) {
            throw new IllegalArgumentException("batch size must be positive: " + size);
        }
        int total = chunks.size();
        List<List<Chunk>> result = new ArrayList<List<Chunk>>(16);
        for (int start = 0; start < total; start += size) {
            result.add(chunks.subList(start, Math.min(start + size, total)));
        }
        return result;
    }

    /**
     * Embeds the chunks batch by batch on the shared thread pool and collects
     * the results in submission order.
     *
     * @param chunks chunks to embed
     * @return one {@link VectorChunk} per embedded chunk
     * @throws RuntimeException wrapping the root cause when a batch fails or
     *                          the calling thread is interrupted while waiting
     */
    private List<VectorChunk> batchEmbed(List<Chunk> chunks) {
        logger.info("Starting batch embedding for {} chunks", (Object)chunks.size());
        List<List<Chunk>> batches = this.spitChunks(chunks, this.getBatchSplitSize());
        List<Future<List<VectorChunk>>> resultFutures = new ArrayList<Future<List<VectorChunk>>>(batches.size());
        for (List<Chunk> batch : batches) {
            Future<List<VectorChunk>> future = threadPools.submit(() -> {
                return this.embedBatch(batch);
            });
            resultFutures.add(future);
        }
        List<VectorChunk> vectorChunks = new ArrayList<VectorChunk>(chunks.size());
        try {
            for (Future<List<VectorChunk>> future : resultFutures) {
                vectorChunks.addAll(future.get());
            }
        }
        catch (InterruptedException exception) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            logger.error("Error while waiting for batch embedding to complete", (Throwable)exception);
            throw new RuntimeException(this.getCause(exception));
        }
        catch (ExecutionException exception) {
            logger.error("Error while waiting for batch embedding to complete", (Throwable)exception);
            throw new RuntimeException(this.getCause(exception));
        }
        return vectorChunks;
    }

    /**
     * Embeds one batch, notifies the registered listeners with the number of
     * vectors produced and pairs each vector with its chunk by position.
     *
     * @param batch non-empty batch; the embedding service is chosen from the
     *              entity id of its first chunk
     * @return the vector chunks for this batch
     */
    private List<VectorChunk> embedBatch(List<Chunk> batch) {
        logger.info("Starting batch embedding in thread for {} chunks", (Object)batch.size());
        EmbeddingService embeddingService = this.getEmbeddingService(Collections.singletonList(batch.get(0).getEntityId()));
        List<String> contents = batch.stream().map(Chunk::getContent).collect(Collectors.toList());
        List<EmbeddingVector> embeddingVectors = embeddingService.batchEmbed(contents);
        logger.info("Batch embedding completed, processing {} vectors", (Object)embeddingVectors.size());
        for (EmbeddingCompleteListener embeddingCompleteListener : this.embeddingCompleteListenerList) {
            embeddingCompleteListener.onEmbeddingComplete(embeddingVectors.size());
        }
        List<VectorChunk> vectorChunks = new ArrayList<VectorChunk>(batch.size());
        for (int i = 0; i < embeddingVectors.size(); ++i) {
            vectorChunks.add(VectorChunk.asVector(batch.get(i), embeddingVectors.get(i).getVector()));
        }
        return vectorChunks;
    }

    /**
     * Walks the cause chain to its end.
     *
     * @param e the throwable to unwrap
     * @return the deepest cause, or {@code e} itself when it has none
     */
    private Throwable getCause(Throwable e) {
        Throwable root = e;
        while (root.getCause() != null) {
            root = root.getCause();
        }
        return root;
    }

    /**
     * Placeholder {@link EmbeddingService} that performs no real embedding:
     * every text maps to a vector of the model's dimension filled with
     * {@code 0.1f}.
     */
    static class NoEmbeddingServiceImpl
    implements EmbeddingService {
        private final EmbeddingModel embeddingModel;

        /**
         * @param embeddingModel model whose dimension determines the vector length
         */
        public NoEmbeddingServiceImpl(EmbeddingModel embeddingModel) {
            this.embeddingModel = embeddingModel;
        }

        /**
         * Returns a constant vector; {@code text} is not inspected.
         *
         * @param text ignored
         * @return a vector of the model's dimension with every component {@code 0.1f}
         */
        @Override
        public EmbeddingVector embed(String text) throws EmbeddingException {
            int dimension = this.embeddingModel.getDimension();
            List<Float> components = new ArrayList<Float>(dimension);
            for (int i = 0; i < dimension; ++i) {
                components.add(Float.valueOf(0.1f));
            }
            EmbeddingVector embeddingVector = new EmbeddingVector();
            embeddingVector.setDimension(dimension);
            embeddingVector.setVector(components);
            return embeddingVector;
        }

        /**
         * Returns one constant vector per input text, in order.
         *
         * @param texts texts to "embed"
         * @return a list with {@code texts.size()} constant vectors
         */
        @Override
        public List<EmbeddingVector> batchEmbed(List<String> texts) throws EmbeddingException {
            List<EmbeddingVector> vectors = new ArrayList<EmbeddingVector>(texts.size());
            for (String text : texts) {
                vectors.add(this.embed(text));
            }
            return vectors;
        }

        /** @return an empty string; this placeholder has no model type */
        @Override
        public String getModelType() {
            return "";
        }

        /**
         * NOTE(review): returns 0 although {@link #embed} produces vectors of
         * the model's dimension - confirm whether callers rely on this.
         *
         * @return always 0
         */
        @Override
        public int dimension() {
            return 0;
        }

        /** @return always {@code false} */
        @Override
        public boolean supportBatch() {
            return false;
        }
    }
}

