/*
 * Decompiled with CFR 0.152.
 */
package kd.ai.gai.core.search.index.service.impl;

import com.google.protobuf.ProtocolStringList;
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.DataType;
import io.milvus.grpc.GetCollectionStatisticsResponse;
import io.milvus.grpc.SearchResults;
import io.milvus.param.Constant;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.FlushParam;
import io.milvus.param.collection.GetCollectionStatisticsParam;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.response.GetCollStatResponseWrapper;
import io.milvus.response.SearchResultsWrapper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import kd.ai.gai.core.code.GaiErrorCode;
import kd.ai.gai.core.code.GaiExceptionUtil;
import kd.ai.gai.core.domain.llm.base.Result4Embedding;
import kd.ai.gai.core.engine.Errors;
import kd.ai.gai.core.engine.json.JsonUtil;
import kd.ai.gai.core.enuz.LLM;
import kd.ai.gai.core.enuz.VectorIndexType;
import kd.ai.gai.core.enuz.VectorMetricType;
import kd.ai.gai.core.enuz.repo.SearchDataSource;
import kd.ai.gai.core.rag.milvus.MilvusClientFactory;
import kd.ai.gai.core.search.index.base.IndexFiledType;
import kd.ai.gai.core.search.index.param.structured.StructuredSearchResultData;
import kd.ai.gai.core.search.index.param.vector.VectorCollectionCreateParam;
import kd.ai.gai.core.search.index.param.vector.VectorCollectionDropParam;
import kd.ai.gai.core.search.index.param.vector.VectorCollectionNameParam;
import kd.ai.gai.core.search.index.param.vector.VectorData;
import kd.ai.gai.core.search.index.param.vector.VectorDataAddParam;
import kd.ai.gai.core.search.index.param.vector.VectorDataCountParam;
import kd.ai.gai.core.search.index.param.vector.VectorDataDelParam;
import kd.ai.gai.core.search.index.param.vector.VectorDataSearchParam;
import kd.ai.gai.core.search.index.schema.VectorIndexFieldSchema;
import kd.ai.gai.core.search.index.schema.VectorIndexSchema;
import kd.ai.gai.core.search.index.service.VectorService;
import kd.ai.gai.core.service.embedding.EmbeddingServiceFactory;
import kd.bos.context.RequestContext;
import kd.bos.dataentity.resource.ResManager;
import kd.bos.exception.ErrorCode;
import kd.bos.exception.KDBizException;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;
import kd.bos.session.SystemPropertyUtils;
import kd.bos.util.StringUtils;
import org.jetbrains.annotations.NotNull;

public class MilvusVectorServiceImpl
implements VectorService {
    private static Log logger = LogFactory.getLog(MilvusVectorServiceImpl.class);
    private final String PROPERTY_DATA_ID = "id";
    private final String PROPERTY_NAME = "propertyName";
    private final String PROPERTY_VECTOR = "propertyVector";

    private String getDimension(int dim) {
        return String.format("{\"nlist\":%s}", dim);
    }

    private String getNprobeParam(int dim) {
        String tenantId = RequestContext.get().getTenantId();
        String nprobeRatio = SystemPropertyUtils.getProptyByTenant((String)"milvus.search.nprobe.ratio.val", (String)tenantId);
        Double ratio = StringUtils.isEmpty((String)nprobeRatio) ? Double.valueOf(0.1) : Double.valueOf(nprobeRatio);
        int nprobe = (int)((double)dim * ratio);
        logger.info("milvus search params {}, ratio:{}, nprobe:{}", new Object[]{this.getDimension(dim), ratio, nprobe});
        return String.format("{\"nprobe\":%s, \"offset\":0}", nprobe > 1 ? nprobe : 1);
    }

    @Override
    public boolean createCollection(VectorCollectionCreateParam createParam) {
        logger.info("\u3010RAG\u3011milvus init param:\u3010{}\u3011 start", (Object)JsonUtil.toJson(createParam));
        String collectionName = createParam.getCollectionName();
        if (StringUtils.isEmpty((String)collectionName)) {
            throw new RuntimeException("milvus create collection ,collectionName is not empty");
        }
        if (this.doCreateCollection(createParam)) {
            logger.info("\u3010RAG\u3011milvus collectionName:{} init start", (Object)collectionName);
            VectorMetricType vectorMetricType = createParam.getVectorMetricType();
            List<VectorIndexFieldSchema> fields = this.getSearchFields(createParam.getDim(), vectorMetricType);
            this.createIndex(new VectorIndexSchema(collectionName, createParam.getVectorMetricType(), fields));
            logger.info("\u3010RAG\u3011milvus collectionName:{} init end", (Object)collectionName);
        }
        logger.info("\u3010RAG\u3011milvus collectionName:{} doLoadCollection start", (Object)collectionName);
        this.doLoadCollection(collectionName);
        logger.info("\u3010RAG\u3011milvus collectionName:{} doLoadCollection end", (Object)collectionName);
        return true;
    }

    private MetricType convertMetricType(VectorMetricType metricType) {
        switch (metricType) {
            case L2: {
                return MetricType.L2;
            }
            case COSINE: {
                return MetricType.COSINE;
            }
        }
        String msg = ResManager.loadKDString((String)("milvus \u5411\u91cf\u6570\u636e\u5e93\uff0c\u4e0d\u652f\u6301\u7684\u8ba1\u7b97\u7c7b\u578b:" + (Object)((Object)metricType)), (String)"gai.core.milvus.store", (String)"ai-gai-core", (Object[])new Object[0]);
        logger.error(msg);
        throw new RuntimeException(msg);
    }

    @NotNull
    private List<VectorIndexFieldSchema> getSearchFields(int dim, VectorMetricType vectorMetricType) {
        ArrayList<VectorIndexFieldSchema> fields = new ArrayList<VectorIndexFieldSchema>(3);
        VectorIndexFieldSchema propertyName = new VectorIndexFieldSchema("propertyName", IndexFiledType.VarChar);
        propertyName.setFiledMaxLength(256);
        fields.add(propertyName);
        fields.add(new VectorIndexFieldSchema("propertyVector", IndexFiledType.VECTOR, dim, vectorMetricType, VectorIndexType.IVF_FLAT));
        return fields;
    }

    private boolean doCreateCollection(VectorCollectionCreateParam createParam) {
        logger.info("\u3010RAG\u3011milvus doCreateCollection:\u3010{}\u3011\u5f00\u59cb", (Object)JsonUtil.toJson(createParam));
        MilvusServiceClient milvusServiceClient = MilvusClientFactory.getClient();
        if (milvusServiceClient == null) {
            logger.error("\u3010RAG\u3011milvus init milvus \u672a\u542f\u7528");
            throw new RuntimeException("milvus create collection ,collectionName is not empty");
        }
        String collectionName = createParam.getCollectionName();
        int dim = createParam.getDim();
        String embeddingNumer = createParam.getEmbeddingNumer();
        VectorMetricType vectorMetricType = createParam.getVectorMetricType();
        List<VectorIndexFieldSchema> fields = this.getSearchFields(dim, vectorMetricType);
        if (this.collectionNameExist(collectionName)) {
            String msg = ResManager.loadKDString((String)String.format("collection \u3010%s\u3011 \u5df2\u7ecf\u5b58\u5728\uff0c\u4e0d\u518d\u91cd\u590d\u521b\u5efa", collectionName), (String)"MilvusVectorServiceImpl_2", (String)"ai-gai-core", (Object[])new Object[0]);
            logger.info(msg);
            return false;
        }
        CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder().withCollectionName(collectionName).withDescription(String.format("%s %s Vector collection for GAI", embeddingNumer, dim)).withShardsNum(2).addFieldType(FieldType.newBuilder().withName("id").withDataType(DataType.VarChar).withMaxLength(Integer.valueOf(128)).withPrimaryKey(true).withAutoID(false).build());
        List<FieldType> fieldTypes = this.conversionFiledType(fields);
        for (FieldType fieldType : fieldTypes) {
            if (fieldType.getName().equals("id")) continue;
            builder.addFieldType(fieldType);
        }
        CreateCollectionParam createCollectionParam = builder.build();
        logger.info("\u3010RAG\u3011milvus collectionName:\u3010{}\u3011\u521b\u5efa \u5f00\u59cb", (Object)collectionName);
        R createResult = MilvusClientFactory.getClient().createCollection(createCollectionParam);
        logger.info("\u3010RAG\u3011milvus collectionName:\u3010{}\u3011\u521b\u5efa\u5b8c\u6210\uff0cresult\uff1a{}", (Object)collectionName, (Object)createResult);
        return R.Status.Success.getCode() == createResult.getStatus().intValue();
    }

    public void doLoadCollection(String collectionName) {
        MilvusServiceClient milvusServiceClient = MilvusClientFactory.getClient();
        if (milvusServiceClient == null) {
            logger.error("\u3010RAG\u3011milvus loadCollection {} milvusServiceClient \u521b\u5efa\u5931\u8d25", (Object)collectionName);
            return;
        }
        R loadResult = milvusServiceClient.loadCollection(LoadCollectionParam.newBuilder().withCollectionName(collectionName).build());
        logger.info("\u3010RAG\u3011milvus loadCollection {} result {}", (Object)collectionName, (Object)loadResult);
    }

    private List<FieldType> conversionFiledType(List<VectorIndexFieldSchema> fields) {
        ArrayList<FieldType> fieldTypeList = new ArrayList<FieldType>(5);
        for (VectorIndexFieldSchema field : fields) {
            FieldType.Builder fieldBuilder = FieldType.newBuilder().withName(field.getFieldKey());
            switch (field.getFiledType()) {
                case VECTOR: {
                    if (field.getDim() <= 0) {
                        throw new RuntimeException("milvus create collection dim the value must > 0");
                    }
                    fieldBuilder.withDataType(DataType.FloatVector).withDimension(Integer.valueOf(field.getDim())).withAutoID(false);
                    break;
                }
                case LONG: {
                    fieldBuilder.withDataType(DataType.Int64);
                    break;
                }
                case VarChar: {
                    fieldBuilder.withDataType(DataType.VarChar);
                    fieldBuilder.withMaxLength(Integer.valueOf(field.getFiledMaxLength()));
                    break;
                }
                default: {
                    throw new RuntimeException("field type not supported");
                }
            }
            fieldTypeList.add(fieldBuilder.build());
        }
        return fieldTypeList;
    }

    @Override
    public boolean collectionExist(VectorCollectionNameParam collectionNameParam) {
        logger.info("\u3010RAG\u3011milvus check hasCollection collectionName\uff1a{} start  ", (Object)JsonUtil.toJson(collectionNameParam));
        String collectionName = collectionNameParam.getCollectionName();
        return this.collectionNameExist(collectionName);
    }

    @NotNull
    private boolean collectionNameExist(String collectionName) {
        MilvusServiceClient milvusServiceClient = MilvusClientFactory.getClient();
        R result = milvusServiceClient.hasCollection(HasCollectionParam.newBuilder().withCollectionName(collectionName).build());
        logger.info("\u3010RAG\u3011milvus hasCollection end result:{} ", (Object)result);
        return (Boolean)result.getData();
    }

    @Override
    public boolean dropCollection(VectorCollectionDropParam dropParam) {
        logger.info("\u3010RAG\u3011milvus delete collection start,dropParam: {}", (Object)JsonUtil.toJson(dropParam));
        String collectionName = dropParam.getCollectionName();
        R delResult = MilvusClientFactory.getClient().dropCollection(DropCollectionParam.newBuilder().withCollectionName(collectionName).build());
        logger.info("\u3010RAG\u3011milvus delete collection {} result {}", (Object)collectionName, (Object)delResult);
        return R.Status.Success.getCode() == delResult.getStatus().intValue();
    }

    @Override
    public boolean createIndex(VectorIndexSchema indexSchema) {
        logger.info("\u3010RAG\u3011milvus \u7d22\u5f15\u521b\u5efa\u5f00\u59cb\uff0cindexSchema\uff1a{}", (Object)JsonUtil.toJson(indexSchema));
        String collectionName = indexSchema.getCollectionName();
        List<VectorIndexFieldSchema> indexFields = indexSchema.getIndexFields();
        if (indexFields == null || indexFields.isEmpty()) {
            throw new RuntimeException("milvus create collection index field is not empty");
        }
        MetricType metricType = this.convertMetricType(indexSchema.getVectorMetricType());
        CreateIndexParam.Builder builder = CreateIndexParam.newBuilder();
        for (VectorIndexFieldSchema indexField : indexFields) {
            if (IndexFiledType.VECTOR != indexField.getFiledType()) continue;
            builder.withCollectionName(collectionName).withFieldName(indexField.getFieldKey()).withIndexType(IndexType.IVF_FLAT).withMetricType(metricType).withExtraParam(this.getDimension(indexField.getDim())).withSyncMode(Boolean.FALSE);
        }
        CreateIndexParam createIndexParam = builder.build();
        R indexResult = MilvusClientFactory.getClient().createIndex(createIndexParam);
        logger.info("\u3010RAG\u3011milvus {} \u7d22\u5f15 \u521b\u5efa\u5b8c\u6210\uff0cresult\uff1a{}", (Object)collectionName, (Object)indexResult);
        return R.Status.Success.getCode() == indexResult.getStatus().intValue();
    }

    @Override
    public boolean batchAddData(VectorDataAddParam dataAddParam) {
        String collectionName = dataAddParam.getCollectionName();
        List<VectorData> indexDatas = dataAddParam.getVectorDataList();
        if (StringUtils.isEmpty((String)collectionName)) {
            throw new RuntimeException("milvus collection batch add data. collectionName is not empty");
        }
        if (indexDatas == null || indexDatas.isEmpty()) {
            throw new RuntimeException("milvus collection batch add data. indexData is not empty");
        }
        int size = indexDatas.size();
        ArrayList<String> dataIds = new ArrayList<String>(size);
        ArrayList<String> propertyNames = new ArrayList<String>(size);
        ArrayList<List<Float>> propertyVectorList = new ArrayList<List<Float>>(size);
        ArrayList<InsertParam.Field> fieldList = new ArrayList<InsertParam.Field>();
        for (VectorData indexData : indexDatas) {
            dataIds.add(indexData.getId());
            propertyNames.add(indexData.getPropertyName());
            propertyVectorList.add(indexData.getPropertyVector());
        }
        fieldList.add(new InsertParam.Field("id", dataIds));
        fieldList.add(new InsertParam.Field("propertyName", propertyNames));
        fieldList.add(new InsertParam.Field("propertyVector", propertyVectorList));
        InsertParam insertParam = InsertParam.newBuilder().withCollectionName(collectionName).withFields(fieldList).build();
        R insertResult = MilvusClientFactory.getClient().insert(insertParam);
        logger.info("\u3010RAG\u3011milvus batch add result:{}", (Object)insertResult);
        return R.Status.Success.getCode() == insertResult.getStatus().intValue();
    }

    @Override
    public boolean batchDelData(VectorDataDelParam delParam) {
        logger.info("\u3010RAG\u3011milvus batch del delParam is :{}", (Object)JsonUtil.toJson(delParam));
        String collectionName = delParam.getCollectionName();
        List<VectorDataDelParam.DelData> datas = delParam.getDelDataList();
        if (StringUtils.isEmpty((String)collectionName)) {
            throw new RuntimeException("milvus collection batch add data. collectionName is not empty");
        }
        if (datas == null || datas.isEmpty()) {
            throw new RuntimeException("milvus collection batch del data. idList is not empty");
        }
        ArrayList<String> dataIds = new ArrayList<String>(datas.size());
        for (VectorDataDelParam.DelData data : datas) {
            dataIds.add(data.getId());
        }
        String idsStr = dataIds.stream().map(id -> "\"" + id + "\"").collect(Collectors.joining(","));
        String expr = "id in [" + idsStr + "]";
        DeleteParam deleteParam = DeleteParam.newBuilder().withCollectionName(collectionName).withExpr(expr).build();
        R r = MilvusClientFactory.getClient().delete(deleteParam);
        logger.info("\u3010RAG\u3011milvus batch del {} ,result {}", (Object)expr, (Object)r);
        if (R.Status.Success.getCode() == r.getStatus().intValue()) {
            FlushParam flushParam = FlushParam.newBuilder().addCollectionName(collectionName).withSyncFlushWaitingTimeout(Constant.MAX_WAITING_FLUSHING_TIMEOUT).build();
            R flushResponseR = MilvusClientFactory.getClient().flush(flushParam);
            if (flushResponseR.getStatus().intValue() == R.Status.Success.getCode()) {
                logger.info("\u3010RAG\u3011milvus batch del data flush ok");
            } else {
                logger.error("\u3010RAG\u3011milvus batch del data flush error");
            }
        }
        return R.Status.Success.getCode() == r.getStatus().intValue();
    }

    @Override
    public List<StructuredSearchResultData> search(VectorDataSearchParam searchParam) {
        logger.info("\u3010RAG\u3011milvus search params:{}", (Object)JsonUtil.toJson(searchParam));
        String collectionName = searchParam.getCollectionName();
        String query = searchParam.getQuery();
        int topK = searchParam.getTopK();
        float similarity = searchParam.getSimilarity();
        LLM embeddingModel = searchParam.getEmbeddingModel();
        VectorMetricType vectorMetricType = searchParam.getVectorMetricType();
        int dim = EmbeddingServiceFactory.getExecutor(embeddingModel, vectorMetricType).getDimension();
        MetricType metricType = this.convertMetricType(vectorMetricType);
        List<Float> vector = this.Q2V(query, embeddingModel);
        List<String> search_output_fields = Collections.singletonList("id");
        List<List<Float>> search_vectors = Collections.singletonList(vector);
        SearchParam milvusSearchParam = SearchParam.newBuilder().withCollectionName(collectionName).withMetricType(metricType).withOutFields(search_output_fields).withTopK(Integer.valueOf(topK)).withVectors(search_vectors).withVectorFieldName("propertyVector").withParams(this.getNprobeParam(dim)).build();
        R respSearch = MilvusClientFactory.getClient().search(milvusSearchParam);
        if (respSearch.getStatus() == 0) {
            SearchResults respSearchData = (SearchResults)respSearch.getData();
            logger.info("\u3010RAG\u3011milvus search result {}", (Object)JsonUtil.toJson(respSearchData));
            ProtocolStringList ids = respSearchData.getResults().getIds().getStrId().getDataList();
            logger.info("\u3010RAG\u3011:\u68c0\u7d22\u5757IDS:{}", (Object)ids);
            if (ids != null && ids.size() > 0) {
                ArrayList<StructuredSearchResultData> datas = new ArrayList<StructuredSearchResultData>(ids.size());
                SearchResultsWrapper searchWrapper = new SearchResultsWrapper(respSearchData.getResults());
                List idScores = searchWrapper.getIDScore(0);
                int size = idScores.size();
                for (int i = 0; i < size; ++i) {
                    String id = (String)ids.get(i);
                    SearchResultsWrapper.IDScore idScore = (SearchResultsWrapper.IDScore)idScores.get(i);
                    float milvusScore = idScore.getScore();
                    if (metricType == MetricType.COSINE) {
                        milvusScore = milvusScore >= 1.0f ? 1.0f : (1.0f + milvusScore) / 2.0f;
                    }
                    datas.add(new StructuredSearchResultData(id, SearchDataSource.MILVUS, milvusScore));
                }
                return datas;
            }
            logger.warn("\u3010RAG\u3011\u672c\u6b21\u672a\u68c0\u7d22\u5230\u5757\u6570\u636e");
        }
        logger.error("\u3010RAG\u3011milvus search result {}", (Object)respSearch);
        return Collections.EMPTY_LIST;
    }

    @Override
    public long queryCountByCollectionName(VectorDataCountParam countParam) {
        logger.info("\u3010RAG\u3011milvus query total params:{}", (Object)JsonUtil.toJson(countParam));
        String collectionName = countParam.getCollectionName();
        LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder().withCollectionName(collectionName).withRefresh(Boolean.valueOf(true)).withSyncLoad(Boolean.valueOf(true)).build();
        R loadResult = MilvusClientFactory.getClient().loadCollection(loadCollectionParam);
        logger.info("\u3010RAG\u3011milvus query total by load collectionName :{}", (Object)JsonUtil.toJson(loadResult));
        GetCollectionStatisticsParam statisticsParam = GetCollectionStatisticsParam.newBuilder().withCollectionName(collectionName).build();
        R statisticsResponseR = MilvusClientFactory.getClient().getCollectionStatistics(statisticsParam);
        if (statisticsResponseR.getStatus().intValue() == R.Status.Success.getCode()) {
            GetCollStatResponseWrapper wrapper = new GetCollStatResponseWrapper((GetCollectionStatisticsResponse)statisticsResponseR.getData());
            long total = wrapper.getRowCount();
            logger.info("\u3010RAG\u3011milvus query total by collectionName :{}", (Object)total);
            return total;
        }
        logger.error("\u3010RAG\u3011milvus query total by collectionName error :{}", (Object)JsonUtil.toJson(statisticsResponseR));
        throw new KDBizException(GaiExceptionUtil.buildExtMsgErr(GaiErrorCode.REPO_MILVUS_ERR, statisticsResponseR.getMessage()), new Object[0]);
    }

    @NotNull
    private List<Float> Q2V(String query, LLM embeddingModel) {
        Result4Embedding result4Embedding = EmbeddingServiceFactory.getExecutor(embeddingModel, VectorMetricType.COSINE).embedding(Collections.singletonList(query));
        if (!result4Embedding.getCode().equals(Errors.OK.getCode())) {
            throw new KDBizException(new ErrorCode(result4Embedding.getCode(), result4Embedding.getErrMsg()), new Object[]{result4Embedding.getErrMsg()});
        }
        List<Float> vector = result4Embedding.getVectorList().get(0);
        if (vector == null || vector.isEmpty()) {
            logger.error(Errors.EMBEDDING_ERROR.getMessage());
            throw new KDBizException(Errors.EMBEDDING_ERROR, new Object[0]);
        }
        return vector;
    }
}

