package kd.bos.flydb.core.optimize;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import kd.bos.flydb.common.ServerConfig;
import kd.bos.flydb.common.ServerOption;
import kd.bos.flydb.common.exception.ErrorCode;
import kd.bos.flydb.common.exception.Exceptions;
import kd.bos.flydb.core.Contexts;
import kd.bos.flydb.core.interpreter.ScalarEvaluationCompiler;
import kd.bos.flydb.core.interpreter.bind.BindableAggregate;
import kd.bos.flydb.core.interpreter.bind.BindableFilter;
import kd.bos.flydb.core.interpreter.bind.BindableJoin;
import kd.bos.flydb.core.interpreter.bind.BindableNode;
import kd.bos.flydb.core.interpreter.bind.BindableProject;
import kd.bos.flydb.core.interpreter.bind.BindableSort;
import kd.bos.flydb.core.interpreter.bind.BindableTableScan;
import kd.bos.flydb.core.optimize.rbo.Planner;
import kd.bos.flydb.core.rel.Aggregate;
import kd.bos.flydb.core.rel.Filter;
import kd.bos.flydb.core.rel.Join;
import kd.bos.flydb.core.rel.Project;
import kd.bos.flydb.core.rel.RelNode;
import kd.bos.flydb.core.rel.Sort;
import kd.bos.flydb.core.rel.TableScan;
import kd.bos.flydb.core.rex.BaseRexNodeVisitor;
import kd.bos.flydb.core.rex.BaseRexNodeVisitor1;
import kd.bos.flydb.core.rex.RexCall;
import kd.bos.flydb.core.rex.RexInputRef;
import kd.bos.flydb.core.rex.RexLiteral;
import kd.bos.flydb.core.rex.RexNode;
import kd.bos.flydb.core.rex.RexNodeList;
import kd.bos.flydb.core.rex.RexSubQuery;
import kd.bos.flydb.core.sql.type.DataTypeField;
import kd.bos.flydb.core.sql.type.TupleDataType;

/**
 * Static helpers that turn a logical {@link RelNode} plan into an executable {@link BindableNode} plan.
 *
 * <p>{@link #optimize(RelNode)} runs three steps:
 * <ol>
 *   <li><b>column trimming</b> ({@code trim}): walks the plan from the root down, works out which
 *       input columns each operator actually needs, rebuilds Project/Sort/Filter/Join nodes over the
 *       narrowed inputs and turns every {@link TableScan} into a {@link BindableTableScan} that only
 *       reads the needed columns. Inputs of existing nodes are replaced in place along the way;</li>
 *   <li><b>conversion</b> ({@link #optimize0(RelNode)}): maps every remaining logical node to its
 *       bindable counterpart;</li>
 *   <li><b>rule-based optimisation</b> with {@link Planner}, only when the session option
 *       {@code ServerOption.EnableRBO} is enabled.</li>
 * </ol>
 */
public final class OptimizeHelp {

    /**
     * Expression visitor that records the index of every {@link RexInputRef} it meets in a {@link BitSet}.
     */
    public static class InputRefCollector extends BaseRexNodeVisitor1 {
        /** Indices of the input fields referenced by the visited expressions. */
        private final BitSet bitSet;

        /** Creates a collector that starts from an empty set. */
        public InputRefCollector() {
            this(new BitSet());
        }

        /**
         * Creates a collector that adds into {@code bitSet}. The set is shared, not copied: visiting
         * expressions mutates the caller's instance.
         *
         * @param bitSet set that receives the referenced field indices
         */
        public InputRefCollector(BitSet bitSet) {
            this.bitSet = bitSet;
        }

        @Override // kd.bos.flydb.core.rex.BaseRexNodeVisitor1, kd.bos.flydb.core.rex.RexNodeVisitor1
        public void visitRexInputRef(RexInputRef rexInputRef) {
            this.bitSet.set(rexInputRef.getIndex());
        }
    }

    /**
     * Expression rewriter that renumbers {@link RexInputRef} indices according to an old-to-new mapping.
     *
     * <p>{@link RexCall} operands and {@link RexNodeList} elements are rewritten <em>in place</em>;
     * each {@link RexInputRef} is replaced by a fresh instance. Sub-queries are not renumbered but
     * are converted to bindable form via {@link OptimizeHelp#optimize0(RelNode)}.
     * Every referenced index must be present in the mapping, otherwise a
     * {@link NullPointerException} is thrown while unboxing.
     */
    public static class RexRefMapper extends BaseRexNodeVisitor<RexNode> {
        /** Old input index -> new input index. */
        private final HashMap<Integer, Integer> mapping;

        /**
         * @param hashMap old input index to new input index
         */
        public RexRefMapper(HashMap<Integer, Integer> hashMap) {
            this.mapping = hashMap;
        }

        /** Re-enters the visitor only for node kinds this mapper rewrites; all others are returned untouched. */
        @Override // kd.bos.flydb.core.rex.BaseRexNodeVisitor
        public RexNode visitChildren(RexNode rexNode) {
            if (!(rexNode instanceof RexCall) && !(rexNode instanceof RexNodeList) && !(rexNode instanceof RexInputRef) && !(rexNode instanceof RexSubQuery)) {
                return rexNode;
            }
            return (RexNode) rexNode.accept(this);
        }

        @Override // kd.bos.flydb.core.rex.BaseRexNodeVisitor, kd.bos.flydb.core.rex.RexNodeVisitor
        public RexNode visitRexCall(RexCall rexCall) {
            List<RexNode> operands = rexCall.getOperands();
            for (int i = 0; i < operands.size(); i++) {
                RexNode operand = operands.get(i);
                RexNode rewritten = (RexNode) operand.accept(this);
                if (rewritten != operand) {
                    rexCall.replaceOperand(i, rewritten);
                }
            }
            // Operands may have changed, so the cached digest must be rebuilt.
            rexCall.buildDigest();
            return rexCall;
        }

        @Override // kd.bos.flydb.core.rex.BaseRexNodeVisitor, kd.bos.flydb.core.rex.RexNodeVisitor
        public RexNode visitRexInputRef(RexInputRef rexInputRef) {
            return new RexInputRef(rexInputRef.getType(), this.mapping.get(rexInputRef.getIndex()).intValue());
        }

        @Override // kd.bos.flydb.core.rex.BaseRexNodeVisitor, kd.bos.flydb.core.rex.RexNodeVisitor
        public RexNode visitRexNodeList(RexNodeList rexNodeList) {
            if (rexNodeList == null || rexNodeList.isEmpty()) {
                return rexNodeList;
            }
            for (int i = 0; i < rexNodeList.size(); i++) {
                RexNode element = rexNodeList.get(i);
                RexNode rewritten = (RexNode) element.accept(this);
                if (element != rewritten) {
                    rexNodeList.set(i, rewritten);
                }
            }
            return rexNodeList;
        }

        @Override // kd.bos.flydb.core.rex.BaseRexNodeVisitor, kd.bos.flydb.core.rex.RexNodeVisitor
        public RexNode visitRexSubQuery(RexSubQuery rexSubQuery) {
            BindableNode converted = OptimizeHelp.optimize0(rexSubQuery.getQuery());
            if (converted != rexSubQuery.getQuery()) {
                rexSubQuery.setQuery(converted);
            }
            rexSubQuery.buildDigest();
            return rexSubQuery;
        }
    }

    /**
     * Outcome of trimming one node: the (possibly rebuilt) node plus the mapping from each
     * retained field index of the original node's output to its index in {@link #node}'s output.
     */
    public static class TrimResult {
        /** Replacement for the trimmed node (may be the original instance). */
        public final RelNode node;
        /** Original output field index -> output field index in {@link #node}. */
        public final HashMap<Integer, Integer> mapper;

        public TrimResult(RelNode relNode, HashMap<Integer, Integer> hashMap) {
            this.node = relNode;
            this.mapper = hashMap;
        }
    }

    private OptimizeHelp() {
    }

    /**
     * Trims unused columns from {@code relNode}, converts it to a bindable plan and, if the session
     * option {@code EnableRBO} is on, runs the rule-based {@link Planner} over the result.
     * The input tree may be rewritten in place during trimming.
     *
     * @param relNode root of the logical plan
     * @return root of the executable plan
     */
    public static BindableNode optimize(RelNode relNode) {
        BindableNode root = optimize0(trim(relNode));
        if (ServerConfig.getSessionABCConfiguration(Contexts.get().getSessionId()).getBool(ServerOption.EnableRBO).booleanValue()) {
            Planner planner = new Planner(root);
            planner.optimize();
            root = (BindableNode) planner.getRoot();
        }
        return root;
    }

    /**
     * Folds constant expressions, bottom-up, in every {@link BindableProject} expression list and
     * every {@link BindableFilter} condition of the plan, replacing each with a {@link RexLiteral}.
     * The plan is modified in place.
     *
     * @param bindableNode root of a bindable plan; every input is expected to be a {@link BindableNode}
     * @return {@code bindableNode} itself
     */
    public static BindableNode optimizeConstExpression(BindableNode bindableNode) {
        for (RelNode input : bindableNode.getInputList()) {
            optimizeConstExpression((BindableNode) input);
        }
        if (bindableNode instanceof BindableProject) {
            BindableProject project = (BindableProject) bindableNode.cast(BindableProject.class);
            for (int i = 0; i < project.exprList.size(); i++) {
                RexNode expr = project.exprList.get(i);
                if (expr.isConstExpression()) {
                    project.exprList.set(i, new RexLiteral(expr.getType(), optimizeConstExpression0(expr)));
                }
            }
        } else if (bindableNode instanceof BindableFilter) {
            BindableFilter filter = (BindableFilter) bindableNode.cast(BindableFilter.class);
            // A filter without a condition is legal (trimming handles it), so guard before folding.
            if (filter.condition != null && filter.condition.isConstExpression()) {
                filter.condition = new RexLiteral(filter.condition.getType(), optimizeConstExpression0(filter.condition));
            }
        }
        return bindableNode;
    }

    /**
     * Evaluates a constant expression: a {@link RexCall} is compiled and evaluated with no input
     * row, a {@link RexLiteral} yields its value; anything else is rejected.
     */
    private static Object optimizeConstExpression0(RexNode rexNode) {
        if (rexNode instanceof RexCall) {
            return new ScalarEvaluationCompiler(Contexts.get()).compile(rexNode).eval(new Object[0]);
        }
        if (rexNode instanceof RexLiteral) {
            return ((RexLiteral) rexNode.cast(RexLiteral.class)).getValue();
        }
        throw Exceptions.of(ErrorCode.Unexpected, new Object[]{rexNode.getClass().getName()});
    }

    /**
     * Recursively converts {@code relNode} and all its inputs into bindable nodes (no trimming).
     * Kept public because {@link RexRefMapper} uses it for sub-queries.
     *
     * @param relNode node to convert; {@code null} yields {@code null}
     * @return converted node with converted inputs
     */
    public static BindableNode optimize0(RelNode relNode) {
        BindableNode converted = convert(relNode);
        if (converted == null || relNode.getInputList().isEmpty()) {
            return converted;
        }
        List<RelNode> inputs = relNode.getInputList();
        // Convert all children first, then attach them, so the list being iterated is never
        // modified mid-loop (converted may be relNode itself for already-bindable nodes).
        List<BindableNode> convertedInputs = new ArrayList<>(inputs.size());
        for (RelNode input : inputs) {
            convertedInputs.add(optimize0(input));
        }
        for (int i = 0; i < convertedInputs.size(); i++) {
            converted.replaceInput(i, (RelNode) convertedInputs.get(i));
        }
        return converted;
    }

    /** Trims {@code relNode} while keeping every one of its own output columns. */
    private static RelNode trim(RelNode relNode) {
        BitSet allFields = new BitSet();
        allFields.set(0, relNode.getRowType().getFieldCount());
        return trim0(relNode, allFields).node;
    }

    /**
     * Trims {@code relNode} so that it produces at least the output fields set in {@code bitSet}.
     *
     * @param relNode node to trim
     * @param bitSet  output field indices the parent needs
     * @return trimmed node and the mapping for the required fields
     */
    private static TrimResult trim0(RelNode relNode, BitSet bitSet) {
        // Dispatch order is significant if node classes are related by inheritance; keep it stable.
        if (relNode instanceof Project) {
            return trimProject((Project) relNode.cast(Project.class), bitSet);
        }
        if (relNode instanceof Sort) {
            return trimSort((Sort) relNode.cast(Sort.class), bitSet);
        }
        if (relNode instanceof Filter) {
            return trimFilter((Filter) relNode.cast(Filter.class), bitSet);
        }
        if (relNode instanceof Aggregate) {
            return trimAggregate((Aggregate) relNode.cast(Aggregate.class));
        }
        if (relNode instanceof TableScan) {
            return trimTableScan((TableScan) relNode.cast(TableScan.class), bitSet);
        }
        if (relNode instanceof Join) {
            return trimJoin((Join) relNode.cast(Join.class), bitSet);
        }
        throw Exceptions.of(ErrorCode.Unexpected, new Object[]{relNode.getClass().getName()});
    }

    /** Keeps only the required projection expressions and narrows the input to the columns they reference. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // field-type list element type is not named in this file
    private static TrimResult trimProject(Project project, BitSet required) {
        InputRefCollector collector = new InputRefCollector();
        for (int i = 0; i < project.exprList.size(); i++) {
            if (required.get(i)) {
                project.exprList.get(i).accept(collector);
            }
        }
        if (project.getInputList().isEmpty()) {
            // No input to narrow: keep the project unchanged with an identity mapping.
            return new TrimResult(project, identityMapping(project.exprList.size()));
        }
        TrimResult child = trimSingleChild(project, collector.bitSet);
        RexRefMapper remapper = new RexRefMapper(child.mapper);
        int exprCount = project.exprList.size();
        List<RexNode> keptExprs = new ArrayList<>(exprCount);
        List<String> keptNames = new ArrayList<>(exprCount);
        List keptTypes = new ArrayList(exprCount);
        List<DataTypeField> fields = project.getRowType().getFieldList();
        HashMap<Integer, Integer> mapper = new HashMap<>();
        for (int i = 0; i < exprCount; i++) {
            if (required.get(i)) {
                keptExprs.add((RexNode) project.exprList.get(i).accept(remapper));
                keptNames.add(fields.get(i).getName());
                keptTypes.add(fields.get(i).getType());
                mapper.put(i, mapper.size());
            }
        }
        TupleDataType rowType = new TupleDataType(project.getRowType().id(), keptNames, keptTypes);
        return new TrimResult(new Project(child.node, new RexNodeList(keptExprs), rowType), mapper);
    }

    /** A Sort passes its input through; it needs what its parent needs plus the columns of its sort keys. */
    private static TrimResult trimSort(Sort sort, BitSet required) {
        if (sort.sortItemList.isEmpty()) {
            TrimResult child = trimSingleChild(sort, required);
            return new TrimResult(new Sort(child.node, sort.offset, sort.limit, Collections.emptyList()), child.mapper);
        }
        // Work on a copy so the caller's set is not mutated by the collector.
        InputRefCollector collector = new InputRefCollector((BitSet) required.clone());
        for (Sort.SortItem item : sort.sortItemList) {
            item.expression.accept(collector);
        }
        TrimResult child = trimSingleChild(sort, collector.bitSet);
        RexRefMapper remapper = new RexRefMapper(child.mapper);
        List<Sort.SortItem> items = new ArrayList<>(sort.sortItemList.size());
        for (Sort.SortItem item : sort.sortItemList) {
            items.add(new Sort.SortItem(item.ordering, (RexNode) item.expression.accept(remapper)));
        }
        return new TrimResult(new Sort(child.node, sort.offset, sort.limit, items), child.mapper);
    }

    /** A Filter passes its input through; it needs what its parent needs plus the columns of its condition. */
    private static TrimResult trimFilter(Filter filter, BitSet required) {
        if (filter.condition == null) {
            // No predicate: narrow the input in place and keep this node.
            return new TrimResult(filter, trimSingleChild(filter, required).mapper);
        }
        InputRefCollector collector = new InputRefCollector((BitSet) required.clone());
        filter.condition.accept(collector);
        TrimResult child = trimSingleChild(filter, collector.bitSet);
        RexNode condition = (RexNode) filter.condition.accept(new RexRefMapper(child.mapper));
        return new TrimResult(new Filter(child.node, condition), child.mapper);
    }

    /** An Aggregate keeps all of its own outputs, so its output mapping is the identity. */
    private static TrimResult trimAggregate(Aggregate aggregate) {
        int outputCount = aggregate.groupList.size() + aggregate.aggCallList.size();
        BitSet childRequired = new BitSet();
        childRequired.set(0, outputCount);
        HashMap<Integer, Integer> identity = identityMapping(outputCount);
        // NOTE(review): the child is asked for its first outputCount columns and the child's mapper is
        // discarded, so group keys / aggregate arguments are NOT renumbered. This is only correct if the
        // aggregate's input keeps every referenced column at an unchanged position in [0, outputCount)
        // — confirm against how Aggregate inputs are built.
        trimSingleChild(aggregate, childRequired);
        return new TrimResult(aggregate, identity);
    }

    /** Replaces the scan with a {@link BindableTableScan} that reads only the required columns. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // field-type list element type is not named in this file
    private static TrimResult trimTableScan(TableScan tableScan, BitSet required) {
        int fieldCount = tableScan.getRowType().getFieldCount();
        List<DataTypeField> fields = tableScan.getRowType().getFieldList();
        BitSet keep = required;
        if (keep.isEmpty()) {
            // Nothing referenced: still read column 0 (presumably so the scan keeps producing rows).
            keep = new BitSet();
            keep.set(0);
        }
        HashMap<Integer, Integer> mapper = new HashMap<>(fieldCount);
        List<Integer> columns = new ArrayList<>(fieldCount);
        List<String> names = new ArrayList<>(fields.size());
        List types = new ArrayList(fields.size());
        for (int i = 0; i < fieldCount; i++) {
            if (keep.get(i)) {
                mapper.put(i, mapper.size());
                columns.add(i);
                DataTypeField field = fields.get(i);
                names.add(field.getName());
                types.add(field.getType());
            }
        }
        int[] projection = new int[columns.size()];
        for (int i = 0; i < projection.length; i++) {
            projection[i] = columns.get(i);
        }
        return new TrimResult(new BindableTableScan(tableScan.table, projection, null, names, types), mapper);
    }

    /**
     * Trims both join inputs to the columns needed by the parent and by the join condition.
     * Output indices of the right side are shifted by the original / trimmed left width respectively.
     */
    private static TrimResult trimJoin(Join join, BitSet required) {
        InputRefCollector collector = new InputRefCollector((BitSet) required.clone());
        if (join.condition != null) {
            join.condition.accept(collector);
        }
        // Widths must be read before the inputs are replaced by their trimmed versions.
        int leftCount = join.getInput(0).getRowType().getFieldCount();
        int rightCount = join.getInput(1).getRowType().getFieldCount();
        TrimResult left = trimSingleChild(join, getRange(collector.bitSet, 0, leftCount), 0);
        TrimResult right = trimSingleChild(join, getRange(collector.bitSet, leftCount, leftCount + rightCount), 1);
        HashMap<Integer, Integer> mapper = new HashMap<>(32);
        mapper.putAll(left.mapper);
        int trimmedLeftCount = left.node.getRowType().getFieldCount();
        for (Map.Entry<Integer, Integer> entry : right.mapper.entrySet()) {
            mapper.put(entry.getKey() + leftCount, entry.getValue() + trimmedLeftCount);
        }
        RexNode condition = join.condition == null ? null : (RexNode) join.condition.accept(new RexRefMapper(mapper));
        return new TrimResult(new Join(left.node, right.node, join.joinType, condition), mapper);
    }

    /** Returns the mapping {0->0, 1->1, ..., n-1->n-1}. */
    private static HashMap<Integer, Integer> identityMapping(int n) {
        HashMap<Integer, Integer> identity = new HashMap<>();
        for (int i = 0; i < n; i++) {
            identity.put(i, i);
        }
        return identity;
    }

    /**
     * Returns the bits of {@code bitSet} in {@code [i, i2)}, shifted so that bit {@code i} becomes bit 0.
     * An empty set is returned when {@code i2 <= i}.
     */
    private static BitSet getRange(BitSet bitSet, int i, int i2) {
        if (i2 <= i) {
            return new BitSet();
        }
        return bitSet.get(i, i2);
    }

    /** Trims input 0 of {@code relNode}; see {@link #trimSingleChild(RelNode, BitSet, int)}. */
    private static TrimResult trimSingleChild(RelNode relNode, BitSet bitSet) {
        return trimSingleChild(relNode, bitSet, 0);
    }

    /**
     * Trims input {@code i} of {@code relNode} to the fields in {@code bitSet} and, if the trimmed
     * node differs, installs it as the new input in place.
     */
    private static TrimResult trimSingleChild(RelNode relNode, BitSet bitSet, int i) {
        RelNode input = relNode.getInput(i);
        TrimResult trimmed = trim0(input, bitSet);
        if (input != trimmed.node) {
            relNode.replaceInput(i, trimmed.node);
        }
        return trimmed;
    }

    /**
     * Converts a single node (not its inputs) to its bindable counterpart. Nodes that are already
     * bindable are returned as-is; {@code null} yields {@code null}.
     */
    private static BindableNode convert(RelNode relNode) {
        if (relNode == null) {
            return null;
        }
        if (relNode instanceof BindableNode) {
            return (BindableNode) relNode.cast(BindableNode.class);
        }
        if (relNode instanceof TableScan) {
            return convertTableScan((TableScan) relNode.cast(TableScan.class));
        }
        if (relNode instanceof Sort) {
            return convertSort((Sort) relNode.cast(Sort.class));
        }
        if (relNode instanceof Project) {
            return convertProject((Project) relNode.cast(Project.class));
        }
        if (relNode instanceof Join) {
            return convertJoin((Join) relNode.cast(Join.class));
        }
        if (relNode instanceof Filter) {
            return convertFilter((Filter) relNode.cast(Filter.class));
        }
        if (relNode instanceof Aggregate) {
            return convertAggregate((Aggregate) relNode.cast(Aggregate.class));
        }
        throw Exceptions.of(ErrorCode.Unexpected1, new Object[]{relNode.getClass().getName()});
    }

    /** Builds a scan that reads every field of the row type, using each field's own index. */
    private static BindableNode convertTableScan(TableScan tableScan) {
        List<DataTypeField> fields = tableScan.getRowType().getFieldList();
        int[] projection = new int[fields.size()];
        for (int i = 0; i < fields.size(); i++) {
            projection[i] = fields.get(i).getIndex();
        }
        return new BindableTableScan(tableScan.table, projection, null, tableScan.getRowType());
    }

    private static BindableNode convertProject(Project project) {
        return new BindableProject(project.getInput(0), project.exprList, project.getRowType());
    }

    private static BindableSort convertSort(Sort sort) {
        return new BindableSort(sort.getInput(0), sort.offset, sort.limit, sort.sortItemList);
    }

    private static BindableJoin convertJoin(Join join) {
        return new BindableJoin(join.getInput(0), join.getInput(1), join.joinType, join.condition);
    }

    private static BindableFilter convertFilter(Filter filter) {
        return new BindableFilter(filter.getInput(0), filter.condition);
    }

    private static BindableAggregate convertAggregate(Aggregate aggregate) {
        return new BindableAggregate(aggregate.getInput(0), aggregate.groupList, aggregate.aggCallList, aggregate.getRowType());
    }
}
