/*
 * Decompiled with CFR 0.152.
 */
package kd.ai.gai.core.rag.milvus;

import com.alibaba.fastjson.JSON;
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.DataType;
import io.milvus.grpc.SearchResults;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import kd.ai.gai.core.domain.dto.Chunk;
import kd.ai.gai.core.rag.milvus.IMilvusDao;
import kd.ai.gai.core.rag.milvus.MilvusClientFactory;
import kd.bos.exception.KDBizException;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;

/**
 * Base DAO for one Milvus collection whose schema is built in {@link #init()}:
 * an Int64 primary key {@code id}, an Int64 {@code repoId} and a float vector {@code vector}.
 *
 * <p>The collection name, description and vector dimension come from
 * {@code getCollectionName()}, {@code getCollectionDescription()} and {@code getDimension()},
 * which are not defined in this class (presumably declared by {@link IMilvusDao} and
 * implemented by subclasses).
 *
 * @param <T> not used by this class; NOTE(review): presumably kept for subclass signatures, confirm
 */
public abstract class MilvusBaseDao<T>
implements IMilvusDao {
    private static final Log log = LogFactory.getLog(MilvusBaseDao.class);
    /** Default topK used by the overloads that take no explicit {@code top}. */
    private static final int SEARCH_K = 5;
    /** Upper bound applied to any requested topK. */
    private static final int MAX_TOP = 20;
    /** Search parameters sent with every vector search. */
    private static final String SEARCH_PARAM = "{\"nprobe\":10, \"offset\":0}";
    private static final String FIELD_ID = "id";
    private static final String FIELD_REPO_ID = "repoId";
    private static final String FIELD_VECTOR = "vector";
    /** Joiner for the per-repository clauses of the search filter expression. */
    private static final String OR = " or ";

    /**
     * Builds an IVF_FLAT / L2 index on the {@code vector} field and hands it to
     * {@link #doCreateIndex(CreateIndexParam)}. Sync mode is set to {@code false}, so the
     * request does not ask Milvus to wait for the index build.
     */
    public void createIndex() {
        CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
                .withCollectionName(this.getCollectionName())
                .withFieldName(FIELD_VECTOR)
                .withIndexType(IndexType.IVF_FLAT)
                .withMetricType(MetricType.L2)
                .withExtraParam(this.getDimensionParam())
                .withSyncMode(Boolean.FALSE)
                .build();
        this.doCreateIndex(createIndexParam);
    }

    /**
     * Returns the index extra-param JSON {@code {"nlist":<dimension>}}.
     *
     * <p>NOTE(review): this uses the vector dimension as the IVF {@code nlist} (cluster count);
     * the two are unrelated in Milvus. Confirm this is intentional.
     */
    private String getDimensionParam() {
        return String.format("{\"nlist\":%s}", this.getDimension());
    }

    /**
     * Searches with the default topK ({@value #SEARCH_K}) and returns only the matching ids.
     *
     * @see #searchIds(List, List, int)
     */
    public List<Long> searchIds(List<Float> vector, List<Long> repoIdList) {
        return this.searchIds(vector, repoIdList, SEARCH_K);
    }

    /**
     * Searches and returns only the Int64 ids of the nearest entities.
     *
     * @param vector     query vector
     * @param repoIdList repositories to restrict the search to
     * @param top        requested topK (clamped as in {@link #search(List, List, int)})
     * @return matching ids, or {@code null} when the search failed or was skipped
     */
    public List<Long> searchIds(List<Float> vector, List<Long> repoIdList, int top) {
        SearchResults results = this.search(vector, repoIdList, top);
        if (results == null) {
            return null;
        }
        return results.getResults().getIds().getIntId().getDataList();
    }

    /**
     * Searches with the default topK ({@value #SEARCH_K}).
     *
     * @see #search(List, List, int)
     */
    public SearchResults search(List<Float> vector, List<Long> repoIdList) {
        return this.search(vector, repoIdList, SEARCH_K);
    }

    /** Returns the client supplied by {@link MilvusClientFactory}; may be {@code null}. */
    @Override
    public MilvusServiceClient getClient() {
        return MilvusClientFactory.getClient();
    }

    /**
     * Creates the collection (and, if it was newly created, its vector index), then loads it.
     */
    @Override
    public void init() {
        List<FieldType> fieldTypeList = new ArrayList<FieldType>();
        fieldTypeList.add(FieldType.newBuilder().withName(FIELD_VECTOR).withDataType(DataType.FloatVector).withDimension(Integer.valueOf(this.getDimension())).withAutoID(false).build());
        fieldTypeList.add(FieldType.newBuilder().withName(FIELD_REPO_ID).withDataType(DataType.Int64).build());
        if (this.doCreateCollection(fieldTypeList)) {
            log.info("collectionName:{} milvus createIndex start", (Object)this.getCollectionName());
            this.createIndex();
            log.info("collectionName:{} milvus createIndex end", (Object)this.getCollectionName());
        }
        log.info("collectionName:{} milvus doLoadCollection start", (Object)this.getCollectionName());
        this.doLoadCollection();
        log.info("collectionName:{} milvus doLoadCollection end", (Object)this.getCollectionName());
    }

    /** Submits the index request and logs the result; does nothing when no client is available. */
    @Override
    public void doCreateIndex(CreateIndexParam requestParam) {
        MilvusServiceClient milvusServiceClient = this.getClient();
        if (milvusServiceClient == null) {
            return;
        }
        R<?> indexResult = milvusServiceClient.createIndex(requestParam);
        log.info("{} \u7d22\u5f15 {} \u521b\u5efa\u5b8c\u6210\uff0cresult\uff1a{}", new Object[]{this.getCollectionName(), requestParam.getCollectionName(), indexResult});
    }

    /**
     * Creates the collection unless it already exists.
     *
     * <p>An Int64 primary key {@code id} (no auto-id) is always added; any caller-supplied field
     * named {@code id} is skipped.
     *
     * @param fieldTypeList non-primary-key fields of the collection
     * @return {@code true} only when this call created the collection successfully
     */
    @Override
    public boolean doCreateCollection(List<FieldType> fieldTypeList) {
        log.info("milvus doCreateCollection start ");
        MilvusServiceClient milvusServiceClient = this.getClient();
        if (milvusServiceClient == null) {
            log.info("createCollection milvus\u672a\u542f\u7528");
            return false;
        }
        log.info("milvus hasCollection start ");
        R<?> result = milvusServiceClient.hasCollection(HasCollectionParam.newBuilder().withCollectionName(this.getCollectionName()).build());
        if (result.getData() == null && isCredentialError(result)) {
            log.info("createCollection milvus\u8d26\u53f7\u5bc6\u7801\u4e0d\u5bf9!!!!!!!");
            // NOTE(review): closes the client handed out by MilvusClientFactory; if the factory
            // caches/shares it, subsequent calls may fail. Confirm intended.
            milvusServiceClient.close();
            return false;
        }
        log.info("milvus hasCollection end result:{} ", (Object)result);
        // Value comparison: the response payload is a Boolean that must not be compared by reference.
        if (Boolean.TRUE.equals(result.getData())) {
            log.info("{} collection \u5df2\u7ecf\u5b58\u5728\uff0c\u4e0d\u518d\u91cd\u590d\u521b\u5efa", (Object)this.getCollectionName());
            return false;
        }
        CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder().withCollectionName(this.getCollectionName()).withDescription(this.getCollectionDescription()).withShardsNum(2);
        builder.addFieldType(FieldType.newBuilder().withName(FIELD_ID).withDataType(DataType.Int64).withPrimaryKey(true).withAutoID(false).build());
        for (FieldType fieldType : fieldTypeList) {
            if (FIELD_ID.equals(fieldType.getName())) {
                continue;
            }
            builder.addFieldType(fieldType);
        }
        CreateCollectionParam createCollectionParam = builder.build();
        log.info("{} \u521b\u5efa \u5f00\u59cb", (Object)this.getCollectionName());
        R<?> createResult = milvusServiceClient.createCollection(createCollectionParam);
        log.info("{} \u521b\u5efa\u5b8c\u6210\uff0cresult\uff1a{}", (Object)this.getCollectionName(), (Object)createResult);
        return isSuccess(createResult);
    }

    /** Loads the collection and logs the result; logs an error when no client is available. */
    @Override
    public void doLoadCollection() {
        MilvusServiceClient milvusServiceClient = this.getClient();
        if (milvusServiceClient == null) {
            log.error("loadCollection {} milvusServiceClient \u521b\u5efa\u5931\u8d25", (Object)this.getCollectionName());
            return;
        }
        R<?> loadResult = milvusServiceClient.loadCollection(LoadCollectionParam.newBuilder().withCollectionName(this.getCollectionName()).build());
        log.info("loadCollection {} result {}", (Object)this.getCollectionName(), (Object)loadResult);
    }

    /**
     * Drops the collection.
     *
     * @throws RuntimeException when Milvus reports a non-zero status for the drop
     */
    @Override
    public void delCollection() {
        MilvusServiceClient milvusServiceClient = this.getClient();
        if (milvusServiceClient == null) {
            log.error("delCollection {} milvusServiceClient is null", (Object)this.getCollectionName());
            return;
        }
        R<?> dropResult = milvusServiceClient.dropCollection(DropCollectionParam.newBuilder().withCollectionName(this.getCollectionName()).build());
        // Log before checking so failed drops are recorded too.
        log.info("del  Collection {} result {}", (Object)this.getCollectionName(), (Object)dropResult);
        if (!isSuccess(dropResult)) {
            throw new RuntimeException("\u5220\u9664\u5931\u8d25\u3002\u3002\u3002" + dropResult.getMessage());
        }
    }

    /**
     * Executes the given delete request (despite the name, it removes the entities matched
     * by {@code deleteParam}, not an index). Does nothing when no client is available.
     */
    @Override
    public void deleteIndex(DeleteParam deleteParam) {
        MilvusServiceClient milvusServiceClient = this.getClient();
        if (milvusServiceClient == null) {
            return;
        }
        R<?> deleteResult = milvusServiceClient.delete(deleteParam);
        log.info("milvus delete {}", (Object)deleteResult);
    }

    /**
     * Inserts a single chunk.
     *
     * @return {@code true} when Milvus reports success
     * @throws KDBizException when {@code chunk} is null or its text is empty
     */
    @Override
    public boolean insert(Chunk chunk) {
        if (chunk == null || chunk.getChunk().isEmpty()) {
            throw new KDBizException("\u5206\u6bb5\u4e3a\u7a7a\uff0c\u53ef\u80fd\u662f\u7a7a\u6587\u4ef6");
        }
        R<?> insertResult = this.getClient().insert(this.buildInsertParam(Collections.singletonList(chunk)));
        log.info("milvus insert {}", (Object)insertResult);
        return isSuccess(insertResult);
    }

    /**
     * Inserts all chunks in one request.
     *
     * @throws KDBizException when the list is null/empty or Milvus reports a failure
     */
    @Override
    public void bachInsert(List<Chunk> chunkList) {
        if (chunkList == null || chunkList.isEmpty()) {
            throw new KDBizException("\u5206\u6bb5\u4e3a\u7a7a\uff0c\u53ef\u80fd\u662f\u7a7a\u6587\u4ef6");
        }
        R<?> insertResult = this.getClient().insert(this.buildInsertParam(chunkList));
        log.info("milvus batch insert {}", (Object)insertResult);
        if (!isSuccess(insertResult)) {
            throw new KDBizException("milvus batch insert err:" + insertResult.getMessage());
        }
    }

    /**
     * Runs an L2 vector search restricted to the given repositories.
     *
     * @param vector     query vector
     * @param repoIdList repositories to search; null elements are ignored
     * @param top        requested topK; values {@code <= 0} become 1, values above
     *                   {@value #MAX_TOP} are capped
     * @return the search results, or {@code null} when no client is available, no repository
     *         id was given, or Milvus reports a failure
     */
    @Override
    public SearchResults search(List<Float> vector, List<Long> repoIdList, int top) {
        log.info("milvus search params,top:{}, repoIdList:{}", (Object)top, (Object)JSON.toJSONString(repoIdList));
        int currenTop = top <= 0 ? 1 : Math.min(top, MAX_TOP);
        log.info("milvus search params,exec top:{}", (Object)currenTop);
        String expr = buildRepoIdExpr(repoIdList);
        if (expr.isEmpty()) {
            // An empty filter would search every repository; treat "no repos" as "no results".
            log.info("milvus search skipped, repoIdList is empty");
            return null;
        }
        MilvusServiceClient milvusServiceClient = this.getClient();
        if (milvusServiceClient == null) {
            log.error("search {} milvusServiceClient is null", (Object)this.getCollectionName());
            return null;
        }
        List<String> searchOutputFields = Collections.singletonList(FIELD_REPO_ID);
        List<List<Float>> searchVectors = Collections.singletonList(vector);
        SearchParam searchParam = SearchParam.newBuilder()
                .withCollectionName(this.getCollectionName())
                .withMetricType(MetricType.L2)
                .withOutFields(searchOutputFields)
                .withTopK(Integer.valueOf(currenTop))
                .withVectors(searchVectors)
                .withVectorFieldName(FIELD_VECTOR)
                .withExpr(expr)
                .withParams(SEARCH_PARAM)
                .build();
        R<?> respSearch = milvusServiceClient.search(searchParam);
        if (isSuccess(respSearch)) {
            return (SearchResults)respSearch.getData();
        }
        log.info("milvus search result {}", (Object)respSearch);
        return null;
    }

    /**
     * Deletes entities whose primary key is in {@code idList}; no-op for a null/empty list
     * or when no client is available.
     */
    @Override
    public void delByIdList(List<Long> idList) {
        if (idList == null || idList.isEmpty()) {
            return;
        }
        MilvusServiceClient milvusServiceClient = this.getClient();
        if (milvusServiceClient == null) {
            log.error("delByIdList {} milvusServiceClient is null", (Object)this.getCollectionName());
            return;
        }
        String idStr = idList.stream().map(String::valueOf).collect(Collectors.joining(","));
        StringBuilder expr = new StringBuilder(FIELD_ID).append(" in [").append(idStr).append(']');
        DeleteParam deleteParam = DeleteParam.newBuilder().withCollectionName(this.getCollectionName()).withExpr(expr.toString()).build();
        R<?> r = milvusServiceClient.delete(deleteParam);
        log.info("milvus del {} ,result {}", (Object)expr, (Object)r);
    }

    /** Builds the column-oriented insert request (id / repoId / vector) for the given chunks. */
    private InsertParam buildInsertParam(List<Chunk> chunkList) {
        int size = chunkList.size();
        List<Long> idList = new ArrayList<Long>(size);
        List<Long> repoIdList = new ArrayList<Long>(size);
        List<List<Float>> vectorListList = new ArrayList<List<Float>>(size);
        for (Chunk chunk : chunkList) {
            idList.add(chunk.getId());
            repoIdList.add(chunk.getRepositoryId());
            vectorListList.add(chunk.getVector());
        }
        List<InsertParam.Field> fieldList = new ArrayList<InsertParam.Field>(3);
        fieldList.add(new InsertParam.Field(FIELD_ID, idList));
        fieldList.add(new InsertParam.Field(FIELD_REPO_ID, repoIdList));
        fieldList.add(new InsertParam.Field(FIELD_VECTOR, vectorListList));
        return InsertParam.newBuilder().withCollectionName(this.getCollectionName()).withFields(fieldList).build();
    }

    /**
     * Builds {@code repoId == a or repoId == b ...}; returns "" for a null list or when every
     * element is null.
     */
    private static String buildRepoIdExpr(List<Long> repoIdList) {
        if (repoIdList == null) {
            return "";
        }
        return repoIdList.stream()
                .filter(repoId -> repoId != null)
                .map(repoId -> FIELD_REPO_ID + " == " + repoId)
                .collect(Collectors.joining(OR));
    }

    /** True when the response carries status code 0. */
    private static boolean isSuccess(R<?> result) {
        return result != null && result.getStatus() != null && result.getStatus() == 0;
    }

    /** True when a failed response's exception message looks like an authentication error. */
    private static boolean isCredentialError(R<?> result) {
        Throwable exception = result.getException();
        String message = exception == null ? null : exception.getMessage();
        return message != null && (message.contains("username") || message.contains("password") || message.contains("correct"));
    }
}

