/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.flydb.core.optimize;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import kd.bos.flydb.common.ServerConfig;
import kd.bos.flydb.common.ServerOption;
import kd.bos.flydb.common.config.ABCConfiguration;
import kd.bos.flydb.common.config.Option;
import kd.bos.flydb.common.exception.ErrorCode;
import kd.bos.flydb.common.exception.Exceptions;
import kd.bos.flydb.core.Context;
import kd.bos.flydb.core.Contexts;
import kd.bos.flydb.core.interpreter.ScalarEvaluationCompiler;
import kd.bos.flydb.core.interpreter.bind.BindableAggregate;
import kd.bos.flydb.core.interpreter.bind.BindableFilter;
import kd.bos.flydb.core.interpreter.bind.BindableJoin;
import kd.bos.flydb.core.interpreter.bind.BindableNode;
import kd.bos.flydb.core.interpreter.bind.BindableProject;
import kd.bos.flydb.core.interpreter.bind.BindableSort;
import kd.bos.flydb.core.interpreter.bind.BindableTableScan;
import kd.bos.flydb.core.interpreter.bind.BindableUnion;
import kd.bos.flydb.core.interpreter.scalar.ScalarEvaluation;
import kd.bos.flydb.core.optimize.rbo.Planner;
import kd.bos.flydb.core.rel.Aggregate;
import kd.bos.flydb.core.rel.Filter;
import kd.bos.flydb.core.rel.Join;
import kd.bos.flydb.core.rel.Project;
import kd.bos.flydb.core.rel.RelNode;
import kd.bos.flydb.core.rel.Sort;
import kd.bos.flydb.core.rel.TableScan;
import kd.bos.flydb.core.rel.Union;
import kd.bos.flydb.core.rex.BaseRexNodeVisitor;
import kd.bos.flydb.core.rex.BaseRexNodeVisitor1;
import kd.bos.flydb.core.rex.RexCall;
import kd.bos.flydb.core.rex.RexInputRef;
import kd.bos.flydb.core.rex.RexLiteral;
import kd.bos.flydb.core.rex.RexNode;
import kd.bos.flydb.core.rex.RexNodeList;
import kd.bos.flydb.core.rex.RexSubQuery;
import kd.bos.flydb.core.sql.type.DataType;
import kd.bos.flydb.core.sql.type.DataTypeField;
import kd.bos.flydb.core.sql.type.TupleDataType;

public final class OptimizeHelp {
    /** Not instantiable: this class only exposes static helpers. */
    private OptimizeHelp() {
    }

    /**
     * Turns a logical plan into a bindable plan.
     *
     * <p>The plan is first column-trimmed, then converted node by node into
     * bindable nodes. When the current session's {@code EnableRBO} option is
     * set, the converted tree is additionally run through the {@link Planner}
     * and the planner's root is returned instead.
     *
     * @param relNode root of the logical plan
     * @return root of the bindable plan
     */
    public static BindableNode optimize(RelNode relNode) {
        RelNode trimmed = OptimizeHelp.trim(relNode);
        BindableNode converted = OptimizeHelp.optimize0(trimmed);

        Context context = Contexts.get();
        ABCConfiguration sessionConfig = ServerConfig.getSessionABCConfiguration((String)context.getSessionId());
        boolean rboEnabled = sessionConfig.getBool((Option)ServerOption.EnableRBO).booleanValue();
        if (!rboEnabled) {
            return converted;
        }

        // Rule-based optimisation is switched on for this session.
        Planner planner = new Planner(converted);
        planner.optimize();
        return (BindableNode)planner.getRoot();
    }

    /**
     * Folds constant expressions in a bindable plan, in place.
     *
     * <p>All inputs are processed recursively first. Afterwards, every constant
     * expression of a {@link BindableProject} and a constant condition of a
     * {@link BindableFilter} is evaluated once and replaced by a
     * {@link RexLiteral} of the same type. Other node kinds are left untouched.
     *
     * @param bindableNode root of the (sub)plan to rewrite
     * @return the same {@code bindableNode} instance
     */
    public static BindableNode optimizeConstExpression(BindableNode bindableNode) {
        // An empty input list simply yields no iterations.
        for (RelNode child : bindableNode.getInputList()) {
            OptimizeHelp.optimizeConstExpression((BindableNode)child);
        }

        if (bindableNode instanceof BindableProject) {
            BindableProject project = bindableNode.cast(BindableProject.class);
            for (int pos = 0; pos < project.exprList.size(); ++pos) {
                RexNode expr = project.exprList.get(pos);
                if (expr.isConstExpression()) {
                    Object folded = OptimizeHelp.optimizeConstExpression0(expr);
                    project.exprList.set(pos, new RexLiteral(expr.getType(), folded));
                }
            }
            return bindableNode;
        }

        if (bindableNode instanceof BindableFilter) {
            BindableFilter filter = bindableNode.cast(BindableFilter.class);
            if (filter.condition.isConstExpression()) {
                Object folded = OptimizeHelp.optimizeConstExpression0(filter.condition);
                filter.condition = new RexLiteral(filter.condition.getType(), folded);
            }
        }
        return bindableNode;
    }

    /**
     * Computes the value of a constant expression.
     *
     * <p>A {@link RexCall} is compiled with a {@link ScalarEvaluationCompiler}
     * bound to the current context and evaluated against an empty row; a
     * {@link RexLiteral} simply yields its stored value.
     *
     * @param rexNode the constant expression
     * @return the evaluated value
     * @throws RuntimeException the {@code Unexpected} error built by
     *         {@link Exceptions#of} when the node is neither a call nor a literal
     */
    private static Object optimizeConstExpression0(RexNode rexNode) {
        if (rexNode instanceof RexCall) {
            ScalarEvaluationCompiler compiler = new ScalarEvaluationCompiler(Contexts.get());
            ScalarEvaluation evaluation = compiler.compile(rexNode);
            Object[] emptyRow = new Object[0];
            return evaluation.eval(emptyRow);
        }
        if (!(rexNode instanceof RexLiteral)) {
            throw Exceptions.of((ErrorCode)ErrorCode.Unexpected, (Object[])new Object[]{rexNode.getClass().getName()});
        }
        return rexNode.cast(RexLiteral.class).getValue();
    }

    /**
     * Converts {@code relNode} and, recursively, all of its inputs to bindable nodes.
     *
     * <p>The node itself is converted first; then every input is converted, and
     * only after all inputs are done are they wired into the converted node.
     *
     * @param relNode node to convert; {@code null} yields {@code null}
     * @return the bindable counterpart of {@code relNode}
     */
    private static BindableNode optimize0(RelNode relNode) {
        BindableNode converted = OptimizeHelp.convert(relNode);
        if (converted == null) {
            return converted;
        }
        List<RelNode> children = relNode.getInputList();
        if (children.isEmpty()) {
            return converted;
        }

        // Phase 1: convert every child.
        ArrayList<BindableNode> convertedChildren = new ArrayList<BindableNode>(children.size());
        for (RelNode child : children) {
            convertedChildren.add(OptimizeHelp.optimize0(child));
        }

        // Phase 2: plug the converted children into the converted parent.
        int slot = 0;
        for (BindableNode convertedChild : convertedChildren) {
            converted.replaceInput(slot, (RelNode)convertedChild);
            ++slot;
        }
        return converted;
    }

    /**
     * Trims unused columns below {@code root} while keeping every output
     * column of the root itself.
     *
     * @param root root of the plan
     * @return the trimmed plan root
     */
    private static RelNode trim(RelNode root) {
        // The root's consumer needs all of its fields.
        BitSet everyColumn = new BitSet();
        everyColumn.set(0, root.getRowType().getFieldCount());
        TrimResult trimmed = OptimizeHelp.trim0((RelNode)root, (BitSet)everyColumn);
        return trimmed.node;
    }

    /*
     * WARNING - void declaration
     */
    private static TrimResult trim0(RelNode node, BitSet input) {
        if (node instanceof Union) {
            Union union = node.cast(Union.class);
            TrimResult trimResult = OptimizeHelp.trimSingleChild(union, input, 0);
            OptimizeHelp.trimSingleChild(union, input, 1);
            return new TrimResult(new Union(union.getInputList(), union.getRowType(), union.isAll()), trimResult.mapper);
        }
        if (node instanceof Project) {
            Project project = node.cast(Project.class);
            InputRefCollector collector = new InputRefCollector();
            for (int i = 0; i < project.exprList.size(); ++i) {
                if (!input.get(i)) continue;
                RexNode expr = project.exprList.get(i);
                expr.accept(collector);
            }
            if (project.getInputList().isEmpty()) {
                HashMap<Integer, Integer> newMapping = new HashMap<Integer, Integer>();
                for (int i = 0; i < project.exprList.size(); ++i) {
                    newMapping.put(i, i);
                }
                return new TrimResult(project, newMapping);
            }
            TrimResult trimResult = OptimizeHelp.trimSingleChild(project, collector.bitSet);
            RexRefMapper mapper = new RexRefMapper(trimResult.mapper);
            ArrayList<RexNode> newExprList = new ArrayList<RexNode>(project.exprList.size());
            ArrayList<String> newNameList = new ArrayList<String>(project.exprList.size());
            ArrayList<DataType> newTypeList = new ArrayList<DataType>(project.exprList.size());
            List<DataTypeField> sourceFieldList = project.getRowType().getFieldList();
            HashMap<Integer, Integer> newMapping = new HashMap<Integer, Integer>();
            for (int i = 0; i < project.exprList.size(); ++i) {
                if (!input.get(i)) continue;
                newExprList.add(project.exprList.get(i).accept(mapper));
                newNameList.add(sourceFieldList.get(i).getName());
                newTypeList.add(sourceFieldList.get(i).getType());
                newMapping.put(i, newMapping.size());
            }
            return new TrimResult(new Project(trimResult.node, new RexNodeList(newExprList), new TupleDataType(project.getRowType().id(), newNameList, newTypeList)), newMapping);
        }
        if (node instanceof Sort) {
            Sort sort = node.cast(Sort.class);
            if (sort.sortItemList.isEmpty()) {
                TrimResult trimResult = OptimizeHelp.trimSingleChild(sort, input);
                return new TrimResult(new Sort(trimResult.node, sort.offset, sort.limit, Collections.emptyList()), trimResult.mapper);
            }
            InputRefCollector collector = new InputRefCollector(input);
            for (Sort.SortItem sortItem : sort.sortItemList) {
                sortItem.expression.accept(collector);
            }
            TrimResult trimResult = OptimizeHelp.trimSingleChild(sort, collector.bitSet);
            RexRefMapper mapper = new RexRefMapper(trimResult.mapper);
            ArrayList<Sort.SortItem> newSortList = new ArrayList<Sort.SortItem>(sort.sortItemList.size());
            for (Sort.SortItem item1 : sort.sortItemList) {
                RexNode expr = item1.expression.accept(mapper);
                Sort.SortItem item2 = new Sort.SortItem(item1.ordering, expr);
                newSortList.add(item2);
            }
            return new TrimResult(new Sort(trimResult.node, sort.offset, sort.limit, newSortList), trimResult.mapper);
        }
        if (node instanceof Filter) {
            Filter filter = node.cast(Filter.class);
            if (filter.condition == null) {
                TrimResult trimResult = OptimizeHelp.trimSingleChild(filter, input);
                return new TrimResult(filter, trimResult.mapper);
            }
            InputRefCollector collector = new InputRefCollector(input);
            filter.condition.accept(collector);
            TrimResult trimResult = OptimizeHelp.trimSingleChild(filter, collector.bitSet);
            RexRefMapper mapper = new RexRefMapper(trimResult.mapper);
            RexNode condition2 = filter.condition.accept(mapper);
            return new TrimResult(new Filter(trimResult.node, condition2), trimResult.mapper);
        }
        if (node instanceof Aggregate) {
            Aggregate aggregate = node.cast(Aggregate.class);
            int count = aggregate.groupList.size() + aggregate.aggCallList.size();
            BitSet bitSet = new BitSet();
            HashMap<Integer, Integer> newMapping = new HashMap<Integer, Integer>();
            for (int i = 0; i < count; ++i) {
                bitSet.set(i);
                newMapping.put(i, i);
            }
            OptimizeHelp.trimSingleChild(aggregate, bitSet);
            return new TrimResult(aggregate, newMapping);
        }
        if (node instanceof TableScan) {
            TableScan tableScan = node.cast(TableScan.class);
            ArrayList<Integer> projectList = new ArrayList<Integer>(tableScan.getRowType().getFieldCount());
            HashMap<Integer, Integer> map = new HashMap<Integer, Integer>(tableScan.getRowType().getFieldCount());
            List<DataTypeField> fieldList = tableScan.getRowType().getFieldList();
            ArrayList<String> fieldNameList = new ArrayList<String>(fieldList.size());
            ArrayList<DataType> fieldTypeList = new ArrayList<DataType>(fieldList.size());
            if (input.isEmpty()) {
                input = new BitSet();
                input.set(0);
            }
            for (int i = 0; i < tableScan.getRowType().getFieldCount(); ++i) {
                if (!input.get(i)) continue;
                map.put(i, map.size());
                projectList.add(i);
                DataTypeField field = fieldList.get(i);
                fieldNameList.add(field.getName());
                fieldTypeList.add(field.getType());
            }
            int[] project = new int[projectList.size()];
            for (int i = 0; i < projectList.size(); ++i) {
                project[i] = (Integer)projectList.get(i);
            }
            BindableTableScan newTableScan = new BindableTableScan(tableScan.table, project, null, fieldNameList, fieldTypeList);
            return new TrimResult(newTableScan, map);
        }
        if (node instanceof Join) {
            void var15_68;
            Join join = node.cast(Join.class);
            InputRefCollector collector = new InputRefCollector(input);
            if (join.condition != null) {
                join.condition.accept(collector);
            }
            RelNode left = join.getInput(0);
            RelNode right = join.getInput(1);
            int leftCount = left.getRowType().getFieldCount();
            int rightCount = right.getRowType().getFieldCount();
            BitSet leftUsing = OptimizeHelp.getRange(collector.bitSet, 0, leftCount);
            TrimResult leftTrimResult = OptimizeHelp.trimSingleChild(join, leftUsing, 0);
            BitSet rightUsing = OptimizeHelp.getRange(collector.bitSet, leftCount, leftCount + rightCount);
            TrimResult rightTrimResult = OptimizeHelp.trimSingleChild(join, rightUsing, 1);
            HashMap<Integer, Integer> mapper = new HashMap<Integer, Integer>(32);
            for (Map.Entry<Integer, Integer> entry : leftTrimResult.mapper.entrySet()) {
                mapper.put(entry.getKey(), entry.getValue());
            }
            int newLeftCount = leftTrimResult.node.getRowType().getFieldCount();
            for (Map.Entry<Integer, Integer> entry : rightTrimResult.mapper.entrySet()) {
                mapper.put(entry.getKey() + leftCount, entry.getValue() + newLeftCount);
            }
            RexRefMapper rexRefMapper = new RexRefMapper(mapper);
            Object var15_66 = null;
            if (join.condition != null) {
                RexNode rexNode = join.condition.accept(rexRefMapper);
            }
            Join join2 = new Join(leftTrimResult.node, rightTrimResult.node, join.joinType, (RexNode)var15_68);
            return new TrimResult(join2, mapper);
        }
        throw Exceptions.of((ErrorCode)ErrorCode.Unexpected, (Object[])new Object[0]);
    }

    /**
     * Extracts the bits of {@code bitSet} in {@code [start, end)} and shifts
     * them down so that bit {@code start} becomes bit 0 of the result.
     *
     * @param bitSet source set, not modified
     * @param start  first index to copy (inclusive)
     * @param end    index to stop at (exclusive)
     * @return a new BitSet; empty when the range is empty
     */
    private static BitSet getRange(BitSet bitSet, int start, int end) {
        BitSet shifted = new BitSet();
        if (start >= end) {
            return shifted;
        }
        // Hop from set bit to set bit instead of probing every index.
        int bit = bitSet.nextSetBit(start);
        while (bit >= 0 && bit < end) {
            shifted.set(bit - start);
            bit = bitSet.nextSetBit(bit + 1);
        }
        return shifted;
    }

    /**
     * Trims input 0 of {@code node}; shorthand for
     * {@code trimSingleChild(node, inputBitSet, 0)}.
     *
     * @param node        parent whose first input is trimmed
     * @param inputBitSet columns of that input the parent needs
     * @return the trim result of the input
     */
    private static TrimResult trimSingleChild(RelNode node, BitSet inputBitSet) {
        return OptimizeHelp.trimSingleChild(node, inputBitSet, 0);
    }

    /**
     * Trims the {@code i}-th input of {@code node} and, if trimming produced a
     * different node instance, installs it as the new {@code i}-th input.
     *
     * @param node        parent node, modified in place when the input changes
     * @param inputBitSet columns of the input the parent needs
     * @param i           index of the input to trim
     * @return the trim result of the input
     */
    private static TrimResult trimSingleChild(RelNode node, BitSet inputBitSet, int i) {
        RelNode before = node.getInput(i);
        TrimResult outcome = OptimizeHelp.trim0(before, inputBitSet);
        boolean replaced = outcome.node != before;
        if (replaced) {
            node.replaceInput(i, outcome.node);
        }
        return outcome;
    }

    /**
     * Converts a single node (not its inputs) to its bindable counterpart.
     *
     * <p>A node that already is a {@link BindableNode} is returned as is; this
     * is tested before any of the logical node kinds.
     *
     * @param relNode node to convert; {@code null} yields {@code null}
     * @return the bindable node
     * @throws RuntimeException the {@code Unexpected1} error built by
     *         {@link Exceptions#of} for an unsupported node class
     */
    private static BindableNode convert(RelNode relNode) {
        if (relNode == null) {
            return null;
        }
        BindableNode bindable;
        if (relNode instanceof BindableNode) {
            bindable = relNode.cast(BindableNode.class);
        } else if (relNode instanceof TableScan) {
            bindable = OptimizeHelp.convertTableScan(relNode.cast(TableScan.class));
        } else if (relNode instanceof Sort) {
            bindable = OptimizeHelp.convertSort(relNode.cast(Sort.class));
        } else if (relNode instanceof Project) {
            bindable = OptimizeHelp.convertProject(relNode.cast(Project.class));
        } else if (relNode instanceof Join) {
            bindable = OptimizeHelp.convertJoin(relNode.cast(Join.class));
        } else if (relNode instanceof Filter) {
            bindable = OptimizeHelp.convertFilter(relNode.cast(Filter.class));
        } else if (relNode instanceof Aggregate) {
            bindable = OptimizeHelp.convertAggregate(relNode.cast(Aggregate.class));
        } else if (relNode instanceof Union) {
            bindable = OptimizeHelp.convertUnion(relNode.cast(Union.class));
        } else {
            throw Exceptions.of((ErrorCode)ErrorCode.Unexpected1, (Object[])new Object[]{relNode.getClass().getName()});
        }
        return bindable;
    }

    /**
     * Builds a {@link BindableUnion} from the union's input list, row type and
     * {@code all} flag. The inputs are passed through unconverted.
     */
    private static BindableNode convertUnion(Union union) {
        return new BindableUnion(union.getInputList(), union.getRowType(), union.all);
    }

    /**
     * Builds a {@link BindableTableScan} that projects every field of the
     * scan's row type, using each field's own index as the projected column.
     *
     * @param tableScan logical scan
     * @return the bindable scan over the same table and row type
     */
    private static BindableNode convertTableScan(TableScan tableScan) {
        List<DataTypeField> fields = tableScan.getRowType().getFieldList();
        int[] project = new int[fields.size()];
        int cursor = 0;
        for (DataTypeField field : fields) {
            project[cursor] = field.getIndex();
            ++cursor;
        }
        return new BindableTableScan(tableScan.table, project, null, tableScan.getRowType());
    }

    /**
     * Builds a {@link BindableProject} over the project's first input, reusing
     * its expression list and row type. The input is passed through unconverted.
     */
    private static BindableNode convertProject(Project project) {
        return new BindableProject(project.getInput(0), project.exprList, project.getRowType());
    }

    /**
     * Builds a {@link BindableSort} over the sort's first input, reusing its
     * offset, limit and sort items. The input is passed through unconverted.
     */
    private static BindableSort convertSort(Sort sort) {
        return new BindableSort(sort.getInput(0), sort.offset, sort.limit, sort.sortItemList);
    }

    /**
     * Builds a {@link BindableJoin} over the join's two inputs, reusing its
     * join type and condition. The inputs are passed through unconverted.
     */
    private static BindableJoin convertJoin(Join join) {
        return new BindableJoin(join.getInput(0), join.getInput(1), join.joinType, join.condition);
    }

    /**
     * Builds a {@link BindableFilter} over the filter's first input, reusing
     * its condition. The input is passed through unconverted.
     */
    private static BindableFilter convertFilter(Filter filter) {
        return new BindableFilter(filter.getInput(0), filter.condition);
    }

    /**
     * Builds a {@link BindableAggregate} over the aggregate's first input,
     * reusing its group list, aggregate calls and row type. The input is
     * passed through unconverted.
     */
    private static BindableAggregate convertAggregate(Aggregate aggregate) {
        return new BindableAggregate(aggregate.getInput(0), aggregate.groupList, aggregate.aggCallList, aggregate.getRowType());
    }

    /**
     * Expression visitor that renumbers {@link RexInputRef}s according to an
     * old-index to new-index map.
     *
     * <p>Calls and node lists are rewritten in place (their changed operands /
     * items are replaced on the existing object); an input reference is
     * replaced by a new {@link RexInputRef}. For a sub-query, the nested query
     * is converted with {@code OptimizeHelp.optimize0} and stored back on the
     * sub-query.
     */
    private static class RexRefMapper
    extends BaseRexNodeVisitor<RexNode> {
        /** Old column index -> new column index. */
        private final HashMap<Integer, Integer> mapping;

        public RexRefMapper(HashMap<Integer, Integer> mapping) {
            this.mapping = mapping;
        }

        /** Dispatches the four node kinds handled by this visitor; anything else is returned unchanged. */
        @Override
        protected RexNode visitChildren(RexNode rexNode) {
            boolean handled = rexNode instanceof RexCall
                    || rexNode instanceof RexNodeList
                    || rexNode instanceof RexInputRef
                    || rexNode instanceof RexSubQuery;
            if (!handled) {
                return rexNode;
            }
            return rexNode.accept(this);
        }

        /** Rewrites every operand, replaces those that changed, then rebuilds the call's digest. */
        @Override
        public RexNode visitRexCall(RexCall call) {
            List<RexNode> operands = call.getOperands();
            for (int pos = 0; pos < operands.size(); ++pos) {
                RexNode before = operands.get(pos);
                RexNode after = before.accept(this);
                if (after != before) {
                    call.replaceOperand(pos, after);
                }
            }
            call.buildDigest();
            return call;
        }

        /** Creates a reference of the same type pointing at the mapped index. */
        @Override
        public RexNode visitRexInputRef(RexInputRef inputRef) {
            return new RexInputRef(inputRef.getType(), this.mapping.get(inputRef.getIndex()));
        }

        /** Rewrites every item of the list in place; a null or empty list is returned as is. */
        @Override
        public RexNode visitRexNodeList(RexNodeList nodeList) {
            if (nodeList == null || nodeList.isEmpty()) {
                return nodeList;
            }
            for (int pos = 0; pos < nodeList.size(); ++pos) {
                RexNode before = nodeList.get(pos);
                RexNode after = before.accept(this);
                if (before != after) {
                    nodeList.set(pos, after);
                }
            }
            return nodeList;
        }

        /** Converts the nested query and rebuilds the sub-query's digest. */
        @Override
        public RexNode visitRexSubQuery(RexSubQuery subQuery) {
            BindableNode converted = OptimizeHelp.optimize0(subQuery.getQuery());
            if (converted != subQuery.getQuery()) {
                subQuery.setQuery(converted);
            }
            subQuery.buildDigest();
            return subQuery;
        }
    }

    /**
     * Expression visitor that records the index of every visited
     * {@link RexInputRef} in a {@link BitSet}.
     *
     * <p>When a BitSet is supplied it is used directly, not copied, so
     * collecting adds bits to the caller's set.
     */
    private static class InputRefCollector
    extends BaseRexNodeVisitor1 {
        /** Indices of all input references seen so far. */
        private BitSet bitSet;

        /** Starts collecting into a fresh, empty set. */
        public InputRefCollector() {
            this(new BitSet());
        }

        /** Collects into {@code bitSet}, on top of the bits already set in it. */
        public InputRefCollector(BitSet bitSet) {
            this.bitSet = bitSet;
        }

        @Override
        public void visitRexInputRef(RexInputRef inputRef) {
            int referenced = inputRef.getIndex();
            this.bitSet.set(referenced);
        }
    }

    /** Outcome of trimming one node: the resulting node plus its column renumbering. */
    private static class TrimResult {
        /** The node to use in place of the trimmed one (may be the same instance). */
        public final RelNode node;
        /** Maps an output column index of the node before trimming to its index afterwards. */
        public final HashMap<Integer, Integer> mapper;

        public TrimResult(RelNode node, HashMap<Integer, Integer> mapper) {
            this.node = node;
            this.mapper = mapper;
        }
    }
}

