package kd.ai.gai.core.rag.milvus;

import com.alibaba.fastjson.JSON;
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.DataType;
import io.milvus.grpc.SearchResults;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import kd.ai.gai.core.Constant;
import kd.ai.gai.core.domain.dto.Chunk;
import kd.bos.exception.KDBizException;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;

/* loaded from: input_file:kd/ai/gai/core/rag/milvus/MilvusBaseDao.class */
/**
 * Base DAO for a Milvus collection holding chunk vectors.
 *
 * <p>Each row has an {@code id} Int64 primary key, a repository id
 * ({@link Constant.RepoVector#F_REPOD_ID}) and a float vector ({@link Constant.RepoVector#F_VECTOR}).
 * The collection name, description and vector dimension come from {@code getCollectionName()},
 * {@code getCollectionDescription()} and {@code getDimension()}, which are not defined in this class
 * (presumably declared in {@link IMilvusDao} and implemented by subclasses — confirm).
 *
 * <p>The client is obtained from {@link MilvusClientFactory#getClient()}. A {@code null} client is
 * treated by this class as "Milvus not enabled".
 *
 * @param <T> not used by this base class
 */
public abstract class MilvusBaseDao<T> implements IMilvusDao {
    private static final Log log = LogFactory.getLog(MilvusBaseDao.class);
    /** Default number of nearest neighbours returned when the caller does not specify one. */
    private static final int SEARCH_K = 5;
    /** Upper bound applied to the caller-requested top-K. */
    private static final int MAX_TOP = 20;
    /** Search parameters passed verbatim to Milvus. */
    private static final String SEARCH_PARAM = "{\"nprobe\":10, \"offset\":0}";

    /**
     * Requests an asynchronous (sync mode off) IVF_FLAT index with L2 metric on the vector field.
     */
    public void createIndex() {
        CreateIndexParam param = CreateIndexParam.newBuilder()
                .withCollectionName(getCollectionName())
                .withFieldName(Constant.RepoVector.F_VECTOR)
                .withIndexType(IndexType.IVF_FLAT)
                .withMetricType(MetricType.L2)
                .withExtraParam(getDimensionParam())
                .withSyncMode(Boolean.FALSE)
                .build();
        doCreateIndex(param);
    }

    /**
     * Returns the index extra-params JSON, using the vector dimension as {@code nlist}.
     * NOTE(review): nlist is the IVF cluster count, not a dimension; confirm this coupling is intended.
     */
    private String getDimensionParam() {
        return String.format("{\"nlist\":%s}", Integer.valueOf(getDimension()));
    }

    /** Searches with the default top-K and returns the matching primary keys, or null on failure. */
    public List<Long> searchIds(List<Float> list, List<Long> list2) {
        return searchIds(list, list2, SEARCH_K);
    }

    /**
     * Searches and returns the matching primary keys.
     *
     * @return the ids of the hits, or {@code null} when the search could not be performed
     */
    public List<Long> searchIds(List<Float> list, List<Long> list2, int i) {
        SearchResults search = search(list, list2, i);
        if (search == null) {
            return null;
        }
        return search.getResults().getIds().getIntId().getDataList();
    }

    /** Searches with the default top-K; see {@link #search(List, List, int)}. */
    public SearchResults search(List<Float> list, List<Long> list2) {
        return search(list, list2, SEARCH_K);
    }

    @Override
    public MilvusServiceClient getClient() {
        return MilvusClientFactory.getClient();
    }

    /**
     * Ensures the collection exists (creating it and its index if it was just created), then loads it.
     */
    @Override
    public void init() {
        List<FieldType> fields = new ArrayList<>();
        fields.add(FieldType.newBuilder()
                .withName(Constant.RepoVector.F_VECTOR)
                .withDataType(DataType.FloatVector)
                .withDimension(Integer.valueOf(getDimension()))
                .withAutoID(false)
                .build());
        fields.add(FieldType.newBuilder()
                .withName(Constant.RepoVector.F_REPOD_ID)
                .withDataType(DataType.Int64)
                .build());
        if (doCreateCollection(fields)) {
            log.info("collectionName:{} milvus createIndex start", getCollectionName());
            createIndex();
            log.info("collectionName:{} milvus createIndex end", getCollectionName());
        }
        log.info("collectionName:{} milvus doLoadCollection start", getCollectionName());
        doLoadCollection();
        log.info("collectionName:{} milvus doLoadCollection end", getCollectionName());
    }

    /** Sends the index request; silently does nothing when Milvus is not enabled. */
    @Override
    public void doCreateIndex(CreateIndexParam createIndexParam) {
        MilvusServiceClient client = getClient();
        if (client == null) {
            return;
        }
        log.info("{} 索引 {} 创建完成，result：{}", new Object[]{getCollectionName(), createIndexParam.getCollectionName(), client.createIndex(createIndexParam)});
    }

    /**
     * Creates the collection with an {@code id} Int64 primary key plus the given fields
     * (any field named {@code id} in {@code list} is skipped).
     *
     * @return {@code true} only if the collection was newly created successfully; {@code false} when
     *     Milvus is disabled, credentials are rejected, the collection already exists, or creation failed
     */
    @Override
    public boolean doCreateCollection(List<FieldType> list) {
        log.info("milvus doCreateCollection start ");
        MilvusServiceClient client = getClient();
        if (client == null) {
            log.info("createCollection milvus未启用");
            return false;
        }
        log.info("milvus hasCollection start ");
        R<Boolean> hasCollection = client.hasCollection(HasCollectionParam.newBuilder().withCollectionName(getCollectionName()).build());
        if (hasCollection.getData() == null && isAuthFailure(hasCollection)) {
            log.info("createCollection milvus账号密码不对!!!!!!!");
            // NOTE(review): this closes the client handed out by MilvusClientFactory; confirm it is not shared.
            client.close();
            return false;
        }
        log.info("milvus hasCollection end result:{} ", hasCollection);
        // equals() instead of ==: getData() is a boxed Boolean, reference comparison is fragile.
        if (Boolean.TRUE.equals(hasCollection.getData())) {
            log.info("{} collection 已经存在，不再重复创建", getCollectionName());
            return false;
        }
        CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder()
                .withCollectionName(getCollectionName())
                .withDescription(getCollectionDescription())
                .withShardsNum(2);
        builder.addFieldType(FieldType.newBuilder().withName("id").withDataType(DataType.Int64).withPrimaryKey(true).withAutoID(false).build());
        for (FieldType fieldType : list) {
            // The primary key is always added above; never add a second "id" field.
            if (!"id".equals(fieldType.getName())) {
                builder.addFieldType(fieldType);
            }
        }
        log.info("{} 创建 开始", getCollectionName());
        R<?> createCollection = client.createCollection(builder.build());
        log.info("{} 创建完成，result：{}", getCollectionName(), createCollection);
        return isSuccess(createCollection);
    }

    /** Loads the collection; logs an error when Milvus is not enabled. */
    @Override
    public void doLoadCollection() {
        MilvusServiceClient client = getClient();
        if (client == null) {
            log.error("loadCollection {} milvusServiceClient 创建失败", getCollectionName());
        } else {
            log.info("loadCollection {} result {}", getCollectionName(), client.loadCollection(LoadCollectionParam.newBuilder().withCollectionName(getCollectionName()).build()));
        }
    }

    /**
     * Drops the collection.
     *
     * @throws RuntimeException if Milvus is not enabled or the drop request fails
     */
    @Override
    public void delCollection() {
        MilvusServiceClient client = getClient();
        if (client == null) {
            throw new RuntimeException("drop collection " + getCollectionName() + " failed: milvus client unavailable");
        }
        R<?> dropCollection = client.dropCollection(DropCollectionParam.newBuilder().withCollectionName(getCollectionName()).build());
        if (!isSuccess(dropCollection)) {
            throw new RuntimeException("drop collection " + getCollectionName() + " failed: " + dropCollection.getMessage());
        }
        log.info("del  Collection {} result {}", getCollectionName(), dropCollection);
    }

    /** Executes the given delete request; does nothing when Milvus is not enabled. */
    @Override
    public void deleteIndex(DeleteParam deleteParam) {
        MilvusServiceClient client = getClient();
        if (client == null) {
            return;
        }
        client.delete(deleteParam);
    }

    /**
     * Inserts a single chunk.
     *
     * @return {@code true} if Milvus reported success, {@code false} on failure or when Milvus is disabled
     * @throws KDBizException if the chunk is null or has empty content
     */
    @Override
    public boolean insert(Chunk chunk) {
        if (chunk == null || chunk.getChunk().isEmpty()) {
            throw new KDBizException("分段为空，可能是空文件");
        }
        MilvusServiceClient client = getClient();
        if (client == null) {
            log.error("insert {} milvus未启用", getCollectionName());
            return false;
        }
        R<?> insert = client.insert(buildInsertParam(Collections.singletonList(chunk)));
        log.info("milvus batch insert {}", insert);
        // Status 0 means success, consistent with doCreateCollection/bachInsert.
        return isSuccess(insert);
    }

    /**
     * Inserts all chunks in one request.
     *
     * @throws KDBizException if the list is empty, Milvus is disabled, or the insert fails
     */
    @Override
    public void bachInsert(List<Chunk> list) {
        if (list == null || list.isEmpty()) {
            throw new KDBizException("分段为空，可能是空文件");
        }
        MilvusServiceClient client = getClient();
        if (client == null) {
            throw new KDBizException("milvus batch insert err: milvus未启用");
        }
        R<?> insert = client.insert(buildInsertParam(list));
        log.info("milvus batch insert {}", insert);
        if (!isSuccess(insert)) {
            throw new KDBizException("milvus batch insert err:" + insert.getMessage());
        }
    }

    /** Builds the column-oriented insert request (id, repository id, vector) for the given chunks. */
    private InsertParam buildInsertParam(List<Chunk> chunks) {
        int size = chunks.size();
        List<Long> ids = new ArrayList<>(size);
        List<Long> repoIds = new ArrayList<>(size);
        List<Object> vectors = new ArrayList<>(size);
        for (Chunk chunk : chunks) {
            ids.add(Long.valueOf(chunk.getId()));
            repoIds.add(Long.valueOf(chunk.getRepositoryId()));
            vectors.add(chunk.getVector());
        }
        List<InsertParam.Field> fields = new ArrayList<>(3);
        fields.add(new InsertParam.Field("id", ids));
        fields.add(new InsertParam.Field(Constant.RepoVector.F_REPOD_ID, repoIds));
        fields.add(new InsertParam.Field(Constant.RepoVector.F_VECTOR, vectors));
        return InsertParam.newBuilder().withCollectionName(getCollectionName()).withFields(fields).build();
    }

    /**
     * Runs an L2 vector search restricted to the given repository ids.
     *
     * @param list  the query vector
     * @param list2 repository ids to restrict the search to; an empty/null list yields {@code null}
     * @param i     requested top-K, clamped to [1, {@value #MAX_TOP}]
     * @return the search results, or {@code null} if the search was not performed or failed
     */
    @Override
    public SearchResults search(List<Float> list, List<Long> list2, int i) {
        log.info("milvus search params,top:{}, repoIdList:{}", Integer.valueOf(i), JSON.toJSONString(list2));
        int min = Math.max(1, Math.min(i, MAX_TOP));
        log.info("milvus search params,exec top:{}", Integer.valueOf(min));
        if (list2 == null || list2.isEmpty()) {
            // An empty filter would either be an invalid expression or search every repository.
            log.error("milvus search skipped, empty repoIdList, collection {}", getCollectionName());
            return null;
        }
        MilvusServiceClient client = getClient();
        if (client == null) {
            log.error("milvus search {} milvus未启用", getCollectionName());
            return null;
        }
        // NOTE(review): the filter uses the literal field name "repoId"; confirm it matches F_REPOD_ID.
        String expr = list2.stream().map(id -> "repoId == " + id).collect(Collectors.joining(" or "));
        SearchParam param = SearchParam.newBuilder()
                .withCollectionName(getCollectionName())
                .withMetricType(MetricType.L2)
                .withOutFields(Collections.singletonList(Constant.RepoVector.F_REPOD_ID))
                .withTopK(Integer.valueOf(min))
                .withVectors(Collections.singletonList(list))
                .withVectorFieldName(Constant.RepoVector.F_VECTOR)
                .withExpr(expr)
                .withParams(SEARCH_PARAM)
                .build();
        R<SearchResults> search = client.search(param);
        if (isSuccess(search)) {
            return search.getData();
        }
        log.error("milvus search result {}", search);
        return null;
    }

    /** Deletes the rows whose primary key is in {@code list}; no-op for empty input or disabled Milvus. */
    @Override
    public void delByIdList(List<Long> list) {
        if (list == null || list.isEmpty()) {
            return;
        }
        MilvusServiceClient client = getClient();
        if (client == null) {
            log.error("delByIdList {} milvus未启用", getCollectionName());
            return;
        }
        String expr = list.stream().map(id -> String.valueOf(id)).collect(Collectors.joining(",", "id in [", "]"));
        log.info("milvus del {} ,result {}", expr, client.delete(DeleteParam.newBuilder().withCollectionName(getCollectionName()).withExpr(expr).build()));
    }

    /** Returns true when the response carries the success status code (0). */
    private static boolean isSuccess(R<?> response) {
        return response != null && response.getStatus() != null && response.getStatus().intValue() == 0;
    }

    /** Returns true when the response's exception message looks like a credential rejection. */
    private static boolean isAuthFailure(R<?> response) {
        Exception exception = response.getException();
        String message = exception == null ? null : exception.getMessage();
        return message != null
                && (message.contains("username") || message.contains("password") || message.contains("correct"));
    }
}
