/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.gptas.embedding;

import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.QueryResults;
import io.milvus.param.R;
import io.milvus.param.dml.QueryParam;
import io.milvus.response.QueryResultsWrapper;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import kd.bos.data.BusinessDataReader;
import kd.bos.dataentity.entity.DynamicObject;
import kd.bos.dataentity.metadata.dynamicobject.DynamicObjectType;
import kd.bos.dataentity.serialization.SerializationUtils;
import kd.bos.entity.EntityMetadataCache;
import kd.bos.entity.MainEntityType;
import kd.bos.exception.ErrorCode;
import kd.bos.exception.KDBizException;
import kd.bos.gptas.embedding.EmbeddingService;
import kd.bos.gptas.embedding.wrapper.LLMWrapper;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;
import kd.bos.util.ExceptionUtils;

public class EmbeddingServiceImpl
implements EmbeddingService {
    private static final Log logger = LogFactory.getLog(EmbeddingServiceImpl.class);
    private final LLMWrapper llmWrapper = new LLMWrapper();

    @Override
    public List<Map<String, Object>> search(List<Long> repoIdList, String searchText) {
        String className = "";
        try {
            if (repoIdList == null || repoIdList.isEmpty()) {
                throw new RuntimeException("repId is empty.");
            }
            Long repoId = repoIdList.get(0);
            Object embeddingService = this.getEmbeddingService(this.getLLMByRepo(repoId));
            className = embeddingService.getClass().getSimpleName();
            Method searchMethod = embeddingService.getClass().getMethod("search", List.class, String.class);
            Object result = searchMethod.invoke(embeddingService, repoIdList, searchText);
            logger.info(String.format("invoke EmbeddingService class: %s , method:search end", className));
            return (List)SerializationUtils.fromJsonString((String)SerializationUtils.toJsonString((Object)result), List.class);
        }
        catch (Exception e) {
            logger.error(String.format("Invoke class: %s , method: search Error! msg:%s, stack:%s", className, e.getMessage(), ExceptionUtils.getExceptionStackTraceMessage((Exception)e)));
            if (e instanceof InvocationTargetException) {
                InvocationTargetException ite = (InvocationTargetException)e;
                Throwable t = ite.getTargetException();
                String msg = t.getMessage();
                if (msg == null) {
                    msg = t.getClass().getSimpleName();
                }
                throw new KDBizException(t, new ErrorCode("EmbeddingServiceImpl.search", msg), new Object[]{msg});
            }
            throw new KDBizException((Throwable)e, new ErrorCode("EmbeddingServiceImpl.search", e.getMessage()), new Object[]{e.getMessage()});
        }
    }

    private Object getEmbeddingService(Object llm) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException, ClassNotFoundException {
        Class<?> serviceClass = Class.forName("kd.ai.gai.core.service.EmbeddingServiceFactory");
        Method getExecutorMethod = serviceClass.getMethod("getExecutor", this.llmWrapper.getLLMClaz());
        return getExecutorMethod.invoke(null, llm);
    }

    private Object getLLMByRepo(Long repoId) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
        MainEntityType repoType = EntityMetadataCache.getDataEntityType((String)"gai_repo_info");
        DynamicObject repo = BusinessDataReader.loadSingle((Object)repoId, (DynamicObjectType)repoType);
        return this.llmWrapper.getLLM(repo.getString("index_method"));
    }

    private Object getMilvuService(Object llm) throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        Class<?> serviceClass = Class.forName("kd.ai.gai.core.service.milvus.MilvusService");
        Method getExecutorMethod = serviceClass.getMethod("getExecutor", this.llmWrapper.getLLMClaz());
        return getExecutorMethod.invoke(null, llm);
    }

    private MilvusServiceClient getMilvusServiceClient() throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        Class<?> claz = Class.forName("kd.ai.gai.core.service.milvus.MilvusClientFactory");
        Method getClientMethod = claz.getMethod("getClient", new Class[0]);
        return (MilvusServiceClient)getClientMethod.invoke(null, new Object[0]);
    }

    @Override
    public List<Map<String, Object>> reSearch(List<Long> repoIdList, List<Long> chunkIdList) {
        String className = "";
        try {
            if (repoIdList == null || repoIdList.isEmpty()) {
                throw new RuntimeException("repId is empty.");
            }
            Object llm = this.getLLMByRepo(repoIdList.get(0));
            Object instance = this.getEmbeddingService(this.getLLMByRepo(repoIdList.get(0)));
            className = instance.getClass().getSimpleName();
            Map<Long, List<Float>> vector = this.getVector(llm, chunkIdList);
            Method searchMethod = instance.getClass().getMethod("search", List.class, List.class);
            ArrayList<Map<String, Object>> listResult = new ArrayList<Map<String, Object>>(16);
            for (Map.Entry<Long, List<Float>> kv : vector.entrySet()) {
                Object result = searchMethod.invoke(instance, repoIdList, kv.getValue());
                listResult.addAll((Collection)SerializationUtils.fromJsonString((String)SerializationUtils.toJsonString((Object)result), List.class));
            }
            logger.info(String.format("invoke EmbeddingService class: %s , method:search end", className));
            return listResult;
        }
        catch (Exception e) {
            logger.error(String.format("Invoke class: %s , method: search Error! msg:%s, stack:%s", className, e.getMessage(), ExceptionUtils.getExceptionStackTraceMessage((Exception)e)));
            if (e instanceof InvocationTargetException) {
                InvocationTargetException ite = (InvocationTargetException)e;
                Throwable t = ite.getTargetException();
                String msg = t.getMessage();
                if (msg == null) {
                    msg = t.getClass().getSimpleName();
                }
                throw new KDBizException(t, new ErrorCode("EmbeddingServiceImpl.reSearch", msg), new Object[]{msg});
            }
            throw new KDBizException((Throwable)e, new ErrorCode("EmbeddingServiceImpl.reSearch", e.getMessage()), new Object[]{e.getMessage()});
        }
    }

    private Map<Long, List<Float>> getVector(Object llm, List<Long> chunkIds) throws NoSuchFieldException, ClassNotFoundException, IllegalAccessException, InvocationTargetException, NoSuchMethodException {
        MilvusServiceClient client = this.getMilvusServiceClient();
        Object dao = this.getMilvuService(llm);
        String collectionName = (String)dao.getClass().getMethod("getCollectionName", new Class[0]).invoke(dao, new Object[0]);
        ArrayList<String> list = new ArrayList<String>();
        list.add("id");
        list.add("vector");
        QueryParam queryParam = QueryParam.newBuilder().withCollectionName(collectionName).withOutFields(list).withConsistencyLevel(ConsistencyLevelEnum.STRONG).withExpr("id in [1914429023124045829]").withOutFields(list).withOffset(Long.valueOf(0L)).withLimit(Long.valueOf(10L)).build();
        R queryResults = client.query(queryParam);
        if (queryResults.getStatus().intValue() != R.Status.Success.getCode()) {
            logger.info(queryResults.getMessage());
        }
        QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper((QueryResults)queryResults.getData());
        List ids = queryResultsWrapper.getFieldWrapper("id").getFieldData();
        List vectors = queryResultsWrapper.getFieldWrapper("vector").getFieldData();
        HashMap<Long, List<Float>> results = new HashMap<Long, List<Float>>(16);
        for (int i = 0; i < ids.size(); ++i) {
            Long id = (Long)ids.get(i);
            logger.info("vectors:" + Arrays.deepToString(((List)vectors.get(i)).toArray()));
            results.put(id, (List<Float>)vectors.get(i));
        }
        return results;
    }
}

