/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.gptas.common.vectortask.processor;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import kd.bos.context.RequestContext;
import kd.bos.gptas.api.vector.Chunk;
import kd.bos.gptas.api.vector.VectorTask;
import kd.bos.gptas.api.vector.VectorTaskItem;
import kd.bos.gptas.common.VectorStorePlugin;
import kd.bos.gptas.common.embedding.model.EmbeddingVector;
import kd.bos.gptas.common.embedding.service.EmbeddingService;
import kd.bos.gptas.common.vectordb.model.VectorChunk;
import kd.bos.gptas.common.vectordb.service.VectorStoreService;
import kd.bos.gptas.common.vectortask.VectorTaskManager;
import kd.bos.gptas.common.vectortask.storage.VectorTaskStorage;
import kd.bos.gptas.utils.SystemPropertyUtils;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;
import kd.bos.threads.ThreadPools;
import kd.bos.util.Pair;

public class VectorBatchProcessor {
    protected static Log logger = LogFactory.getLog(VectorBatchProcessor.class);
    private static final ExecutorService customThreadPool = ThreadPools.newCachedExecutorService((String)"VectorBatchProcessor", (int)20, (int)50);
    private static final int BATCH_SPLIT_SIZE = 16;
    private final VectorTaskManager vectorTaskManager;
    protected final EmbeddingService embeddingService;
    protected final VectorStoreService vectorStoreService;
    private final VectorTaskStorage vectorTaskStorage;
    private final List<VectorStorePlugin> vectorStorePlugins;

    private int getBatchSplitSize() {
        return SystemPropertyUtils.getInteger(RequestContext.get().getTenantId(), "VECTOR_BATCH_SPLIT_SIZE", 16);
    }

    public VectorBatchProcessor(VectorTaskManager vectorTaskManager, EmbeddingService embeddingService, VectorStoreService vectorStoreService, VectorTaskStorage vectorTaskStorage, List<VectorStorePlugin> vectorStorePlugins) {
        this.vectorTaskManager = vectorTaskManager;
        this.embeddingService = embeddingService;
        this.vectorStoreService = vectorStoreService;
        this.vectorTaskStorage = vectorTaskStorage;
        this.vectorStorePlugins = vectorStorePlugins;
    }

    public CompletableFuture<Void> process(String taskId, List<Chunk> chunks) {
        return CompletableFuture.runAsync(() -> {
            try {
                logger.info("Vector task {} processing started", (Object)taskId);
                this.vectorTaskStorage.updateVectorTaskStatus(taskId, VectorTask.VectorTaskStatus.PROCESSING, this.vectorTaskManager.getTaskLock());
                this.splitAndProcessChunks(taskId, chunks).join();
                this.updateFinalTaskStatus(taskId, this.vectorTaskManager.getTaskLock());
                logger.info("Vector task {} processing completed", (Object)taskId);
            }
            catch (Exception e) {
                this.updateFinalTaskStatus(taskId, this.vectorTaskManager.getTaskLock());
                logger.error("Failed to process vector task: " + taskId, (Throwable)e);
                throw e;
            }
        }, customThreadPool);
    }

    private CompletableFuture<Void> splitAndProcessChunks(String taskId, List<Chunk> chunks) {
        ArrayList<ArrayList<Chunk>> batches = new ArrayList<ArrayList<Chunk>>(chunks.size());
        int batchSplitSize = this.getBatchSplitSize();
        for (int i = 0; i < chunks.size(); i += batchSplitSize) {
            int end = Math.min(chunks.size(), i + batchSplitSize);
            batches.add(new ArrayList<Chunk>(chunks.subList(i, end)));
        }
        logger.info("Task {} split into {} batches, batch size: {}", new Object[]{taskId, batches.size(), batchSplitSize});
        List<CompletableFuture> batchFutures = batches.stream().map(batchChunks -> CompletableFuture.runAsync(() -> this.processBatchWithNotify(taskId, (List<Chunk>)batchChunks), customThreadPool)).collect(Collectors.toList());
        return CompletableFuture.allOf(batchFutures.toArray(new CompletableFuture[0])).exceptionally(throwable -> {
            logger.error("Error processing batches for task: " + taskId, throwable);
            throw new RuntimeException((Throwable)throwable);
        });
    }

    protected void processBatchWithNotify(String taskId, List<Chunk> chunks) {
        try {
            logger.info("Processing batch with notification for task {}, chunks: {}", (Object)taskId, (Object)chunks.size());
            Pair<List<Chunk>, List<Chunk>> processBatchResult = this.processBatch(taskId, chunks);
            List successChunks = (List)processBatchResult.getKey();
            List failedChunks = (List)processBatchResult.getValue();
            logger.info("Task {} batch processed - Success: {}, Failed: {}, Total: {}", new Object[]{taskId, successChunks.size(), failedChunks.size(), chunks.size()});
            this.vectorTaskStorage.updateTaskProcessed(taskId, successChunks.size(), failedChunks.size(), chunks.size(), this.vectorTaskManager.getTaskLock());
            if (!successChunks.isEmpty()) {
                this.vectorTaskStorage.updateVectorTaskItemStatus(taskId, successChunks, VectorTaskItem.VectorTaskItemStatus.SUCCESS);
                logger.info("Task {} executing callback for {} successful chunks", (Object)taskId, (Object)successChunks.size());
                if (this.vectorTaskManager.getChunkResultCallback() != null) {
                    this.vectorTaskManager.getChunkResultCallback().accept(successChunks);
                }
            }
        }
        catch (Throwable e) {
            logger.error("Failed to process batch for task: " + taskId, e);
            this.markBatchFailed(taskId, chunks, e.getMessage());
            this.vectorTaskStorage.updateVectorTaskItemStatus(taskId, chunks, VectorTaskItem.VectorTaskItemStatus.FAILED);
        }
    }

    protected Pair<List<Chunk>, List<Chunk>> processBatch(String taskId, List<Chunk> chunks) {
        this.vectorTaskStorage.updateVectorTaskItemStatus(taskId, chunks, VectorTaskItem.VectorTaskItemStatus.PROCESSING);
        try {
            List<VectorChunk> vectorChunks = this.batchEmbed(chunks);
            logger.info("Task {} vector embedding completed for {} chunks", (Object)taskId, (Object)chunks.size());
            ArrayList<VectorChunk> vectorChunkList = new ArrayList(16);
            if (this.vectorStorePlugins.isEmpty()) {
                vectorChunkList = vectorChunks;
            } else {
                for (VectorStorePlugin vectorStorePlugin : this.vectorStorePlugins) {
                    vectorChunkList.addAll(vectorStorePlugin.preStore(vectorChunks));
                }
                logger.info("Task {} vector prestore for {} chunks", (Object)taskId, (Object)vectorChunkList.size());
            }
            Pair<List<VectorChunk>, List<VectorChunk>> batchStoreResult = this.vectorStoreService.batchStore(vectorChunkList);
            logger.info("Task {} vector storage completed - Success: {}, Failed: {}", new Object[]{taskId, ((List)batchStoreResult.getKey()).size(), ((List)batchStoreResult.getValue()).size()});
            List successedChunks = ((List)batchStoreResult.getKey()).stream().map(VectorChunk::getChunk).collect(Collectors.toList());
            List failedChunks = ((List)batchStoreResult.getValue()).stream().map(VectorChunk::getChunk).collect(Collectors.toList());
            return new Pair(successedChunks, failedChunks);
        }
        catch (Throwable e) {
            logger.error("Error processing batch for task: " + taskId, e);
            this.markBatchFailed(taskId, chunks, e.getMessage());
            this.vectorTaskStorage.updateVectorTaskItemStatus(taskId, chunks, VectorTaskItem.VectorTaskItemStatus.FAILED);
            return new Pair(new ArrayList(1), chunks);
        }
    }

    private List<VectorChunk> batchEmbed(List<Chunk> chunks) {
        logger.info("Starting batch embedding for {} chunks", (Object)chunks.size());
        ArrayList<VectorChunk> vectorChunks = new ArrayList<VectorChunk>(chunks.size());
        List<EmbeddingVector> embeddingVectors = this.embeddingService.batchEmbed(chunks.stream().map(Chunk::getContent).collect(Collectors.toList()));
        logger.info("Batch embedding completed, processing {} vectors", (Object)embeddingVectors.size());
        for (int i = 0; i < embeddingVectors.size(); ++i) {
            VectorChunk vectorChunk = VectorChunk.asVector(chunks.get(i), embeddingVectors.get(i).getVector());
            vectorChunks.add(vectorChunk);
        }
        return vectorChunks;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void updateFinalTaskStatus(String taskId, Object taskThrealLock) {
        Object object = taskThrealLock;
        synchronized (object) {
            VectorTask vectorTask = this.vectorTaskStorage.getVectorTaskInfo(taskId);
            VectorTask.VectorTaskStatus finalStatus = vectorTask.getTotalChunkCount() == 0 || vectorTask.getSuccessChunkCount() == vectorTask.getTotalChunkCount() ? VectorTask.VectorTaskStatus.COMPLETED : (vectorTask.getSuccessChunkCount() == 0 ? VectorTask.VectorTaskStatus.FAILED : VectorTask.VectorTaskStatus.PARTIALLY_FAILED);
            logger.info("Task {} final status: {}, Total: {}, Success: {}, Failed: {}", new Object[]{taskId, finalStatus, vectorTask.getTotalChunkCount(), vectorTask.getSuccessChunkCount(), vectorTask.getFailedChunkCount()});
            this.vectorTaskStorage.updateVectorTaskStatus(taskId, finalStatus, this.vectorTaskManager.getTaskLock());
        }
    }

    protected void markBatchFailed(String taskId, List<Chunk> chunks, String errorMsg) {
        String traceId = RequestContext.get().getTraceId();
        logger.error("Task {} batch failed - Chunks: {}, TraceId: {}, Error: {}", new Object[]{taskId, chunks.size(), traceId, errorMsg});
        this.vectorTaskStorage.updateVectorTaskItemError(taskId, chunks, traceId);
    }
}

