/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.flydb.core.interpreter.algox;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import kd.bos.algo.Input;
import kd.bos.algo.Output;
import kd.bos.algox.AlgoX;
import kd.bos.algox.DataSetX;
import kd.bos.algox.GroupReduceFunction;
import kd.bos.algox.Grouper;
import kd.bos.algox.JobSession;
import kd.bos.algox.JoinDataSetX;
import kd.bos.algox.MapFunction;
import kd.bos.dataentity.metadata.IDataEntityType;
import kd.bos.flydb.common.AlgoXOption;
import kd.bos.flydb.common.ServerOption;
import kd.bos.flydb.common.exception.ErrorCode;
import kd.bos.flydb.common.exception.Exceptions;
import kd.bos.flydb.core.Context;
import kd.bos.flydb.core.Contexts;
import kd.bos.flydb.core.interpreter.BindNodeCompiler;
import kd.bos.flydb.core.interpreter.Executor;
import kd.bos.flydb.core.interpreter.ScalarEvaluationCompiler;
import kd.bos.flydb.core.interpreter.algox.AggregateFunction;
import kd.bos.flydb.core.interpreter.algox.CompilerHelp;
import kd.bos.flydb.core.interpreter.algox.DataSetOutput;
import kd.bos.flydb.core.interpreter.algox.FilterFunction;
import kd.bos.flydb.core.interpreter.algox.LimitOffsetFunction;
import kd.bos.flydb.core.interpreter.algox.OneRowInput;
import kd.bos.flydb.core.interpreter.algox.ProjectFunction;
import kd.bos.flydb.core.interpreter.algox.RenameFunction;
import kd.bos.flydb.core.interpreter.algox.TableScanInput;
import kd.bos.flydb.core.interpreter.bind.BindableAggregate;
import kd.bos.flydb.core.interpreter.bind.BindableFilter;
import kd.bos.flydb.core.interpreter.bind.BindableJoin;
import kd.bos.flydb.core.interpreter.bind.BindableNode;
import kd.bos.flydb.core.interpreter.bind.BindableProject;
import kd.bos.flydb.core.interpreter.bind.BindableSort;
import kd.bos.flydb.core.interpreter.bind.BindableTableScan;
import kd.bos.flydb.core.interpreter.bind.BindableUnion;
import kd.bos.flydb.core.interpreter.scalar.ScalarEvaluation;
import kd.bos.flydb.core.rel.Aggregate;
import kd.bos.flydb.core.rel.RelNode;
import kd.bos.flydb.core.rel.Sort;
import kd.bos.flydb.core.rex.RexCall;
import kd.bos.flydb.core.rex.RexInputRef;
import kd.bos.flydb.core.rex.RexLiteral;
import kd.bos.flydb.core.rex.RexNode;
import kd.bos.flydb.core.schema.FormAttribute;
import kd.bos.flydb.core.schema.Scanner;
import kd.bos.flydb.core.schema.cosmic.CosmicEntityTable;
import kd.bos.flydb.core.schema.cosmic.CosmicFormAttribute;
import kd.bos.flydb.core.schema.cosmic.IDataEntityTypeProvider;
import kd.bos.flydb.core.sql.tree.SqlJoinType;
import kd.bos.flydb.core.sql.tree.SqlKind;
import kd.bos.flydb.core.sql.type.DataType;
import kd.bos.flydb.core.sql.type.DataTypeField;
import kd.bos.trace.TraceSpan;
import kd.bos.trace.Tracer;
import kd.bos.trace.util.TraceIdUtil;
import kd.bos.xdb.XDBConfig;
import kd.bos.xdb.sharding.config.MainTableConfig;

/**
 * Compiles a tree of {@link BindableNode}s into an AlgoX job graph.
 *
 * <p>Nodes are converted bottom-up: every input is converted before its parent. The
 * {@link DataSetX} produced for each node is remembered in {@link #sourceMap}, keyed by node
 * identity, so a parent can look up the data sets produced by its inputs. The data set of the
 * last converted node (the plan root) is wired to a {@link DataSetOutput}. The returned
 * {@link Executor} commits the job session when invoked.
 *
 * <p>NOTE(review): instances hold mutable per-compilation state (job session, root, source
 * map) and no synchronization is visible here, so an instance should presumably not be shared
 * across concurrent compilations.
 */
public class AlgoXBindNodeCompiler implements BindNodeCompiler {
    /** Data set produced for each converted node; identity-keyed so distinct node objects never alias. */
    private final IdentityHashMap<BindableNode, DataSetX> sourceMap = new IdentityHashMap<>();
    /** AlgoX session that owns the job graph; created lazily on the first {@link #compile}. */
    private JobSession jobSession;
    /** Compiles scalar {@link RexNode} expressions (projections, filters) into evaluations. */
    private ScalarEvaluationCompiler scalarEvaluationCompiler;
    /** Data set produced by the most recently converted node, i.e. the plan root after {@link #convert}. */
    private DataSetX root;
    /** Configuration/context for this compilation; defaults to {@link Contexts#get()}. */
    private Context context;

    /**
     * Creates the AlgoX job session.
     *
     * <p>The job name comes from {@link AlgoXOption#JobName}. When a non-blank trace id is
     * available, it is appended as {@code _<traceId>} so the job can be correlated with the
     * request trace. The session region is taken from {@link ServerOption#AlgoXRegion}.
     */
    private void initJobSession() {
        String jobName = this.context.getConfig(AlgoXOption.JobName.key());
        String traceId = TraceIdUtil.getCurrentTraceIdString();
        if (traceId != null && !traceId.trim().isEmpty()) {
            jobName = jobName + '_' + traceId;
        }
        JobSession session = AlgoX.createSession(jobName);
        session.getContext().setRegion(this.context.getConfig(ServerOption.AlgoXRegion.key()));
        this.jobSession = session;
    }

    /**
     * Converts {@code node} (and its whole input tree) into an AlgoX job and returns an
     * executor that runs it.
     *
     * @param node root of the bindable plan
     * @return an executor that commits the job session, waiting at most the configured
     *     {@link ServerOption#QueryTimeout} seconds, and returns the output row meta and id
     * @throws NullPointerException if conversion produced no root data set
     */
    @Override
    public Executor compile(BindableNode node) {
        if (this.context == null) {
            this.context = Contexts.get();
        }
        if (this.jobSession == null) {
            this.initJobSession();
        }
        if (this.scalarEvaluationCompiler == null) {
            this.scalarEvaluationCompiler = new ScalarEvaluationCompiler(this.context);
        }
        this.convert(node);
        Objects.requireNonNull(this.root, "no AlgoX data set was produced for the plan root");
        List<String> subqueryDataSetIdList = this.scalarEvaluationCompiler.getSubQueryResults();
        DataSetOutput output = new DataSetOutput(CompilerHelp.convertRowType(node.getRowType()), subqueryDataSetIdList);
        this.root.output((Output) output);
        return () -> {
            try (TraceSpan span = Tracer.create("flydb", "executeQuery")) {
                // Fall back to the option's default so a missing setting does not fail with NumberFormatException.
                int timeoutSeconds = Integer.parseInt(this.getStringFromConfig(ServerOption.QueryTimeout));
                this.jobSession.commit(timeoutSeconds, TimeUnit.SECONDS);
                return new Executor.QueryResult(output.getRowMeta(), output.getId());
            }
        };
    }

    /**
     * Converts {@code node} in post-order, records its data set in {@link #sourceMap}, and
     * updates {@link #root}.
     *
     * @throws RuntimeException ({@link ErrorCode#UnsupportedFeature}) for node types this
     *     compiler cannot translate, instead of silently leaving a stale root
     */
    private void convert(BindableNode node) {
        // Inputs first: each converter looks up its inputs' data sets in sourceMap.
        for (RelNode relNode : node.getInputList()) {
            this.convert((BindableNode) relNode);
        }
        if (node instanceof BindableTableScan) {
            this.root = this.convertTableScan(node.cast(BindableTableScan.class));
        } else if (node instanceof BindableSort) {
            this.root = this.convertSort(node.cast(BindableSort.class));
        } else if (node instanceof BindableProject) {
            this.root = this.convertProject(node.cast(BindableProject.class));
        } else if (node instanceof BindableJoin) {
            this.root = this.convertJoin(node.cast(BindableJoin.class));
        } else if (node instanceof BindableFilter) {
            this.root = this.convertFilter(node.cast(BindableFilter.class));
        } else if (node instanceof BindableAggregate) {
            this.root = this.convertAggregate1(node.cast(BindableAggregate.class));
        } else if (node instanceof BindableUnion) {
            this.root = this.convertUnion(node.cast(BindableUnion.class));
        } else {
            throw Exceptions.of(ErrorCode.UnsupportedFeature, new Object[]{node.getClass().getSimpleName()});
        }
    }

    /**
     * Creates the source data set for a table scan.
     *
     * <p>When the sharded path applies (see {@link #tryConvertShardedTableScan}), one input
     * is created per shard scanner. Otherwise a single scanner input is used.
     */
    private DataSetX convertTableScan(BindableTableScan tableScan) {
        DataSetX input = this.tryConvertShardedTableScan(tableScan);
        if (input == null) {
            input = this.jobSession.fromInput((Input) new TableScanInput(tableScan.table.createScanner(tableScan.index, tableScan.filter)));
        }
        this.sourceMap.put(tableScan, input);
        return input;
    }

    /**
     * Builds a multi-scanner input for a sharded Cosmic entity table.
     *
     * <p>Returns {@code null} when any precondition fails:
     * <ul>
     *   <li>{@link ServerOption#EnableShardingTableInput} is off,</li>
     *   <li>XDB is disabled,</li>
     *   <li>the table is not a {@link CosmicEntityTable},</li>
     *   <li>its entity type cannot be loaded,</li>
     *   <li>no sharding config exists for the root entity, or the config is disabled,</li>
     *   <li>the table's form attribute is not a {@link CosmicFormAttribute}.</li>
     * </ul>
     *
     * @throws RuntimeException ({@link ErrorCode#UnsupportedArchiveFeature}) if the root entity
     *     has archive routes configured
     * @throws RuntimeException ({@link ErrorCode#Unexpected1}) if the table yields no scanners
     */
    private DataSetX tryConvertShardedTableScan(BindableTableScan tableScan) {
        if (!Boolean.parseBoolean(this.getStringFromConfig(ServerOption.EnableShardingTableInput))
                || !XDBConfig.isXDBEnabled()
                || !(tableScan.table instanceof CosmicEntityTable)) {
            return null;
        }
        IDataEntityType rootType = IDataEntityTypeProvider.get().load(tableScan.table.getName());
        if (rootType == null) {
            return null;
        }
        // The sharding configuration is registered on the top-most parent entity.
        while (rootType.getParent() != null) {
            rootType = rootType.getParent();
        }
        String name = rootType.getName();
        MainTableConfig configByEntity = XDBConfig.getShardingConfigProvider().getConfigByEntity(name);
        if (configByEntity == null) {
            return null;
        }
        if (!configByEntity.getOptions().getIndexRoute().getAllArchiveRoutes().isEmpty()) {
            throw Exceptions.of(ErrorCode.UnsupportedArchiveFeature, new Object[]{name});
        }
        if (!configByEntity.isEnabled()) {
            return null;
        }
        FormAttribute tableFormAttribute = tableScan.table.getFormAttribute();
        if (!(tableFormAttribute instanceof CosmicFormAttribute)) {
            return null;
        }
        IDataEntityType entityType = ((CosmicFormAttribute) tableFormAttribute).getEntityType();
        Scanner[] scanners = tableScan.table.createScanners(tableScan.index, tableScan.filter, entityType);
        if (scanners.length == 0) {
            throw Exceptions.of(ErrorCode.Unexpected1, new Object[]{"scanners is empty"});
        }
        TableScanInput[] inputs = new TableScanInput[scanners.length];
        for (int i = 0; i < scanners.length; ++i) {
            inputs[i] = new TableScanInput(scanners[i]);
        }
        return this.jobSession.fromInput((Input[]) inputs);
    }

    /**
     * Reads a server option from this compiler's context.
     *
     * @return the configured value, or {@code serverOption.defaultValue()} when unset
     */
    private String getStringFromConfig(ServerOption serverOption) {
        String value = this.context.getConfig(serverOption.key());
        return value != null ? value : serverOption.defaultValue();
    }

    /**
     * Translates ORDER BY and LIMIT/OFFSET.
     *
     * <p>Sort keys must be plain column references. They are rendered as
     * {@code "<field> <ordering>"}, comma-joined, and passed to a single
     * {@link DataSetX#orderBy} call. LIMIT/OFFSET is applied by a single-parallel
     * {@link LimitOffsetFunction} reducer.
     *
     * @throws RuntimeException ({@link ErrorCode#OrderWithComplexExpression}) if a sort key is
     *     not a column reference
     */
    private DataSetX convertSort(BindableSort sort) {
        DataSetX source = this.sourceMap.get(sort.getInput(0).cast(BindableNode.class));
        DataSetX node = source;
        if (!sort.sortItemList.isEmpty()) {
            ArrayList<String> orderByList = new ArrayList<>(sort.sortItemList.size());
            for (Sort.SortItem sortItem : sort.sortItemList) {
                // Check before casting: otherwise the cast fails first and this error is never reported.
                if (!(sortItem.expression instanceof RexInputRef)) {
                    throw Exceptions.of(ErrorCode.OrderWithComplexExpression, new Object[0]);
                }
                int index = sortItem.expression.cast(RexInputRef.class).getIndex();
                String name = source.getRowMeta().getFieldName(index);
                // Locale.ROOT: the keyword must not depend on the JVM locale (e.g. Turkish dotless i).
                String ordering = sortItem.ordering.name().toLowerCase(Locale.ROOT);
                orderByList.add(String.format("%s %s", name, ordering));
            }
            node = source.orderBy(new String[]{String.join(",", orderByList)});
        }
        if (sort.limit != null || sort.offset != null) {
            Integer offset = sort.offset == null ? null : literalToInteger(sort.offset.cast(RexLiteral.class).getValue());
            Integer limit = sort.limit == null ? null : literalToInteger(sort.limit.cast(RexLiteral.class).getValue());
            node = node.reduceGroup((GroupReduceFunction) new LimitOffsetFunction(node.getRowMeta(), offset, limit));
            // LIMIT/OFFSET must see the whole stream, so it runs in a single task.
            node.setSingleParallel(true);
        }
        this.sourceMap.put(sort, node);
        return node;
    }

    /**
     * Converts the value of a LIMIT/OFFSET literal to an {@link Integer}.
     *
     * <p>{@code null} and {@code Integer} pass through. Other {@link Number}s are narrowed
     * exactly. Any other type still fails with a {@link ClassCastException}.
     *
     * @throws ArithmeticException if a numeric value does not fit in an {@code int}
     */
    private static Integer literalToInteger(Object value) {
        if (value == null || value instanceof Integer) {
            return (Integer) value;
        }
        if (value instanceof Number) {
            return Math.toIntExact(((Number) value).longValue());
        }
        return (Integer) value;
    }

    /**
     * Translates a projection into a {@link ProjectFunction} map.
     *
     * <p>A projection without an input (for example {@code SELECT 1}) is evaluated once into a
     * one-row source instead.
     */
    private DataSetX convertProject(BindableProject project) {
        if (project.getInput(0) == null) {
            return this.convertOneRowScalarExpressionValue(project);
        }
        DataSetX source = this.sourceMap.get(project.getInput(0).cast(BindableNode.class));
        ArrayList<ScalarEvaluation> evaluationList = this.compileEvaluations(project.exprList);
        DataSetX node = source.map((MapFunction) new ProjectFunction(project.getRowType(), evaluationList));
        this.sourceMap.put(project, node);
        return node;
    }

    /**
     * Compiles each expression, in order, with the scalar compiler and binds it to
     * {@link #context}.
     */
    private ArrayList<ScalarEvaluation> compileEvaluations(Iterable<? extends RexNode> exprs) {
        ArrayList<ScalarEvaluation> evaluations = new ArrayList<>();
        for (RexNode expr : exprs) {
            ScalarEvaluation evaluation = this.scalarEvaluationCompiler.compile(expr);
            evaluation.setContext(this.context);
            evaluations.add(evaluation);
        }
        return evaluations;
    }

    /**
     * Translates UNION / UNION ALL.
     *
     * <p>For plain UNION, the concatenated data set is de-duplicated on all fields of the
     * union's row type.
     */
    private DataSetX convertUnion(BindableUnion union) {
        DataSetX left = this.sourceMap.get(union.getInput(0).cast(BindableNode.class));
        DataSetX right = this.sourceMap.get(union.getInput(1).cast(BindableNode.class));
        DataSetX node = left.union(right);
        if (!union.all) {
            DataType rowType = union.getRowType();
            String[] distinctFields = new String[rowType.getFieldCount()];
            for (int i = 0; i < distinctFields.length; ++i) {
                distinctFields[i] = rowType.getField(i).getName();
            }
            node = node.distinct(distinctFields);
        }
        this.sourceMap.put(union, node);
        return node;
    }

    /**
     * Evaluates a source-less projection at compile time and wraps the values in a
     * {@link OneRowInput}.
     */
    private DataSetX convertOneRowScalarExpressionValue(BindableProject project) {
        ArrayList<ScalarEvaluation> evaluationList = this.compileEvaluations(project.exprList);
        Object[] data = new Object[evaluationList.size()];
        for (int i = 0; i < data.length; ++i) {
            // No input row exists, so each expression is evaluated against an empty row.
            data[i] = evaluationList.get(i).eval(new Object[0]);
        }
        DataSetX node = this.jobSession.fromInput((Input) new OneRowInput(CompilerHelp.convertRowType(project.getRowType()), data));
        this.sourceMap.put(project, node);
        return node;
    }

    /**
     * Flattens an AND-tree of join conditions into {@code list}.
     *
     * <p>Each accepted leaf is an {@code EQUALS} call whose two operands are both
     * {@link RexInputRef}s.
     *
     * @throws RuntimeException ({@link ErrorCode#UnsupportedSqlJoinCondition1}) if the
     *     condition is {@code null}
     * @throws RuntimeException ({@link ErrorCode#OnlySupportedAndOperatorEquiJoin}) on OR,
     *     non-equality, or non-column operands
     * @throws RuntimeException ({@link ErrorCode#UnsupportedSqlJoinCondition}) on any other
     *     expression shape
     */
    private void unzipJoinCondition(RexNode expr, List<RexNode> list) {
        if (expr == null) {
            throw Exceptions.of(ErrorCode.UnsupportedSqlJoinCondition1, new Object[0]);
        }
        if (expr.getKind() == SqlKind.OR) {
            throw Exceptions.of(ErrorCode.OnlySupportedAndOperatorEquiJoin, new Object[0]);
        }
        if (expr.getKind() == SqlKind.AND) {
            RexCall call = expr.cast(RexCall.class);
            this.unzipJoinCondition(call.getOperand(0), list);
            this.unzipJoinCondition(call.getOperand(1), list);
        } else if (expr instanceof RexCall) {
            RexCall condition = expr.cast(RexCall.class);
            if (condition.getOperator().getKind() != SqlKind.EQUALS) {
                throw Exceptions.of(ErrorCode.OnlySupportedAndOperatorEquiJoin, new Object[0]);
            }
            if (!(condition.getOperand(0) instanceof RexInputRef) || !(condition.getOperand(1) instanceof RexInputRef)) {
                throw Exceptions.of(ErrorCode.OnlySupportedAndOperatorEquiJoin, new Object[0]);
            }
            list.add(condition);
        } else {
            throw Exceptions.of(ErrorCode.UnsupportedSqlJoinCondition, new Object[]{expr.toString()});
        }
    }

    /**
     * Translates an equi-join (INNER/LEFT/RIGHT/FULL).
     *
     * <p>Condition column indexes are in the join's combined row space: the left rel fields
     * followed by the right rel fields. Each equality must reference one column from each side.
     *
     * @throws RuntimeException ({@link ErrorCode#UnsupportedFeature}) for CROSS JOIN
     * @throws RuntimeException with {@link ErrorCode#UnsupportedSqlJoinCondition} (or the codes
     *     of {@link #unzipJoinCondition}) for unsupported conditions
     */
    private DataSetX convertJoin(BindableJoin join) {
        // Checked first: a CROSS JOIN has no condition and would otherwise get the generic error.
        if (join.joinType == SqlJoinType.CROSS) {
            throw Exceptions.of(ErrorCode.UnsupportedFeature, new Object[]{"CROSS JOIN"});
        }
        ArrayList<RexNode> onConditions = new ArrayList<>();
        this.unzipJoinCondition(join.condition, onConditions);
        if (onConditions.isEmpty()) {
            throw Exceptions.of(ErrorCode.UnsupportedSqlJoinCondition1, new Object[0]);
        }

        BindableNode leftInput = join.getInput(0).cast(BindableNode.class);
        BindableNode rightInput = join.getInput(1).cast(BindableNode.class);
        DataSetX leftSource = this.sourceMap.get(leftInput);
        int leftCount = leftInput.getRowType().getFieldCount();
        DataSetX rightSource = this.renameClashingRightFields(leftSource, this.sourceMap.get(rightInput), leftCount);

        JoinDataSetX joinDataSetX;
        switch (join.joinType) {
            case INNER: {
                joinDataSetX = leftSource.join(rightSource);
                break;
            }
            case LEFT: {
                joinDataSetX = leftSource.leftJoin(rightSource);
                break;
            }
            case RIGHT: {
                joinDataSetX = leftSource.rightJoin(rightSource);
                break;
            }
            case FULL: {
                joinDataSetX = leftSource.fullJoin(rightSource);
                break;
            }
            default: {
                throw Exceptions.of(ErrorCode.UnsupportedKeyword, new Object[]{join.joinType});
            }
        }
        for (RexNode onCondition : onConditions) {
            RexCall condition = onCondition.cast(RexCall.class);
            int a = condition.getOperand(0).cast(RexInputRef.class).getIndex();
            int b = condition.getOperand(1).cast(RexInputRef.class).getIndex();
            boolean aIsLeft = a < leftCount;
            boolean bIsLeft = b < leftCount;
            // Both columns on the same side is a filter, not a join key.
            if (aIsLeft == bIsLeft) {
                throw Exceptions.of(ErrorCode.UnsupportedSqlJoinCondition, new Object[]{join.condition.toString()});
            }
            int leftIndex = Math.min(a, b);
            // Use the rel field count, which defines the index space of the RexInputRefs.
            int rightIndex = Math.max(a, b) - leftCount;
            joinDataSetX = joinDataSetX.on(CompilerHelp.getFieldName(leftSource, leftIndex), CompilerHelp.getFieldName(rightSource, rightIndex));
        }
        this.sourceMap.put(join, joinDataSetX);
        return joinDataSetX;
    }

    /**
     * Ensures the right join input's field names do not collide with the left's.
     *
     * <p>A clashing right field {@code n} at position {@code i} is renamed to
     * {@code n + '$' + leftCount + i}, using string concatenation (left count 3 and position 1
     * give {@code "n$31"}). If that name is still taken, {@code _k} is appended until it is
     * unique.
     *
     * @return {@code rightSource} unchanged when nothing clashes, otherwise a renaming map over it
     */
    private DataSetX renameClashingRightFields(DataSetX leftSource, DataSetX rightSource, int leftCount) {
        HashSet<String> usedNames = new HashSet<>(Arrays.asList(leftSource.getRowMeta().getFieldNames()));
        int rightCount = rightSource.getRowMeta().getFieldCount();
        String[] newRightNames = new String[rightCount];
        boolean renamed = false;
        for (int i = 0; i < rightCount; ++i) {
            String name = rightSource.getRowMeta().getFieldName(i);
            if (usedNames.contains(name)) {
                renamed = true;
                String base = name + '$' + leftCount + i;
                name = base;
                for (int k = 1; usedNames.contains(name); ++k) {
                    name = base + '_' + k;
                }
            }
            usedNames.add(name);
            newRightNames[i] = name;
        }
        if (!renamed) {
            return rightSource;
        }
        return rightSource.map((MapFunction) new RenameFunction(rightSource.getRowMeta(), newRightNames));
    }

    /**
     * Translates a filter into a {@link FilterFunction} bound to this compiler's context.
     */
    private DataSetX convertFilter(BindableFilter filter) {
        DataSetX source = this.sourceMap.get(filter.getInput(0).cast(BindableNode.class));
        ScalarEvaluation condition = this.scalarEvaluationCompiler.compile(filter.condition);
        condition.setContext(this.context);
        DataSetX node = source.filter((kd.bos.algox.FilterFunction) new FilterFunction(condition));
        this.sourceMap.put(filter, node);
        return node;
    }

    /**
     * Translates an aggregate into an {@link AggregateFunction} reducer.
     *
     * <p>With GROUP BY keys, the source is grouped first. Without keys, the whole source is
     * reduced as a single group.
     */
    private DataSetX convertAggregate1(BindableAggregate aggregate) {
        DataSetX source = this.sourceMap.get(aggregate.getInput(0).cast(BindableNode.class));
        DataType rowType = aggregate.getRowType();
        List<DataTypeField> fieldList = rowType.getFieldList();
        int groupCount = aggregate.groupList.size();
        String[] groupKeys = new String[groupCount];
        for (int i = 0; i < groupCount; ++i) {
            // NOTE(review): the key name is taken from this aggregate's *output* row type at the
            // position stored in groupList. That matches the input column only when those
            // positions line up (e.g. groupList = [0..k-1]) — confirm against the planner.
            groupKeys[i] = fieldList.get((Integer) aggregate.groupList.get(i)).getName();
        }
        Grouper grouper = groupCount == 0 ? null : source.groupBy(groupKeys);
        ArrayList<AggregateFunction.AggCall> aggCalls = new ArrayList<>(aggregate.aggCallList.size());
        for (Aggregate.AggCall call : aggregate.aggCallList) {
            aggCalls.add(new AggregateFunction.AggCall(call.operator.getKind(), call.index, call.distinct, call.ignoreNull, call.type));
        }
        AggregateFunction function = new AggregateFunction(CompilerHelp.convertRowType(rowType), aggregate.groupList, aggCalls);
        DataSetX out;
        if (grouper == null) {
            out = source.reduceGroup((GroupReduceFunction) function);
        } else {
            out = grouper.reduceGroup((GroupReduceFunction) function);
        }
        this.sourceMap.put(aggregate, out);
        return out;
    }
}

