/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.flydb.core.sql.validate.impl;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import kd.bos.flydb.common.ServerOption;
import kd.bos.flydb.common.exception.ErrorCode;
import kd.bos.flydb.common.exception.Exceptions;
import kd.bos.flydb.core.Contexts;
import kd.bos.flydb.core.schema.virtual.InnerVirtualTableInfo;
import kd.bos.flydb.core.sql.operator.SqlOperator;
import kd.bos.flydb.core.sql.operator.SqlOperators;
import kd.bos.flydb.core.sql.tree.SqlBasicCall;
import kd.bos.flydb.core.sql.tree.SqlCall;
import kd.bos.flydb.core.sql.tree.SqlIdentifier;
import kd.bos.flydb.core.sql.tree.SqlJoin;
import kd.bos.flydb.core.sql.tree.SqlJoinType;
import kd.bos.flydb.core.sql.tree.SqlKind;
import kd.bos.flydb.core.sql.tree.SqlLiteral;
import kd.bos.flydb.core.sql.tree.SqlNode;
import kd.bos.flydb.core.sql.tree.SqlNodeList;
import kd.bos.flydb.core.sql.tree.SqlParserPosition;
import kd.bos.flydb.core.sql.tree.SqlSelect;
import kd.bos.flydb.core.sql.util.BaseASTVisitor;
import kd.bos.flydb.core.sql.util.Pair;
import kd.bos.logging.Log;
import kd.bos.logging.LogFactory;

/**
 * Rule-based, in-place rewriter for FlyDB {@code SELECT} trees.
 *
 * <p>When the config key {@code flydb.sqlOptimizeInVariable} is {@code "true"} and the WHERE clause of a
 * SELECT is a single {@code IN} predicate, {@link #optimize(SqlNode)}:
 * <ol>
 *   <li>rewrites {@code IN (v1, ..., vn)} value lists (see {@link SqlNodeInVariableRewrite}):
 *       <ul>
 *         <li>lists above the {@code SqlOptimizeInVariableMaxSize} option are rejected;</li>
 *         <li>lists above the {@code SqlOptimizeInVariable2JoinThreshold} option become a subquery over an inner
 *             virtual table, recorded in {@link #getSqlNodeInnerVirtualTableInfoMap()};</li>
 *         <li>lists of two or more values below the {@code SqlOptimizeInVariable2OrThreshold} option become
 *             OR-chained equalities;</li>
 *       </ul></li>
 *   <li>turns an uncorrelated {@code x IN (SELECT col FROM t ...)} into
 *       {@code FROM outer INNER JOIN (SELECT col FROM t ...) AS t ON x = t.col}.</li>
 * </ol>
 * All rewrites mutate the given tree via {@code setOperand}/{@code setOperator}.
 *
 * <p>NOTE(review): {@link #instance} is a shared singleton. The identity map it exposes is neither
 * synchronized nor cleared by this class; confirm that callers bound its lifetime and concurrency.
 *
 * <p>NOTE(review): replacing IN with an INNER JOIN only preserves row counts when the joined column is
 * unique in the subquery result. No DISTINCT is added here; confirm this is guaranteed upstream.
 */
public class SqlNodeOptimizer {
    private static final Log logger = LogFactory.getLog(SqlNodeOptimizer.class);
    public static final SqlNodeOptimizer instance = new SqlNodeOptimizer();
    private final SqlNodeInSubPredicate predicate = new SqlNodeInSubPredicate();
    private final SqlNodeInVariableRewrite variablePredicate = new SqlNodeInVariableRewrite(this);
    /** Virtual-table FROM identifiers created by the IN-list rewrite, keyed by node identity. */
    private final Map<SqlNode, InnerVirtualTableInfo> sqlNodeInnerVirtualTableInfoMap = new IdentityHashMap<SqlNode, InnerVirtualTableInfo>();

    /**
     * Returns the live map from generated virtual-table FROM identifiers to their value lists.
     *
     * @return the mutable identity map shared by this optimizer instance
     */
    public Map<SqlNode, InnerVirtualTableInfo> getSqlNodeInnerVirtualTableInfoMap() {
        return this.sqlNodeInnerVirtualTableInfoMap;
    }

    /**
     * Applies the IN-list and IN-subquery rewrites to {@code sqlNode} when enabled by configuration.
     *
     * @param sqlNode the statement to rewrite; only {@code SELECT} nodes whose WHERE is a single IN are touched
     * @return {@code sqlNode} itself, possibly mutated in place
     */
    public SqlNode optimize(SqlNode sqlNode) {
        if (!Boolean.parseBoolean(Contexts.get().getConfig("flydb.sqlOptimizeInVariable")) || sqlNode.getKind() != SqlKind.SELECT) {
            return sqlNode;
        }
        SqlSelect node = sqlNode.cast(SqlSelect.class);
        SqlNode from = node.getOperand(2);
        SqlNode where = node.getOperand(3);
        // Only a WHERE clause that is exactly one IN predicate is handled. Checking the kind directly, rather
        // than casting to SqlBasicCall first, avoids failing on a non-call WHERE such as a bare boolean column.
        if (from == null || where == null || where.getKind() != SqlKind.IN) {
            return sqlNode;
        }
        // Rewrite IN value lists first (virtual table / OR chain). The visitor mutates the tree and records
        // virtual tables in sqlNodeInnerVirtualTableInfoMap, so its return value is not needed here.
        node.accept(this.variablePredicate);
        Pair<Boolean, List<SqlNode>> inSubQueries = where.accept(this.predicate);
        if (inSubQueries == null || !inSubQueries.getType().booleanValue()) {
            return sqlNode;
        }
        for (SqlNode inSubSqlSelect : inSubQueries.getValue()) {
            ThreeTuple<SqlNode, SqlNode, SqlNode> unRelateSubQueryInfo = this.getUnRelateSubQueryInfo(inSubSqlSelect);
            if (unRelateSubQueryInfo == null) {
                continue;
            }
            SqlNode tableNode = unRelateSubQueryInfo.getT1();
            SqlNode selectNode = unRelateSubQueryInfo.getT2();
            List<String> tableAliases = this.getTableAliases(from);
            List<String> subTableAliases = this.getTableAliases(tableNode);
            // Both sides need a name to qualify columns with: an explicit alias or a plain table identifier.
            // Check before touching the tree, so an unsupported shape is skipped instead of failing with a
            // ClassCastException half-way through the rewrite.
            if ((tableAliases.isEmpty() && !(from instanceof SqlIdentifier))
                    || (subTableAliases.isEmpty() && !(tableNode instanceof SqlIdentifier))) {
                continue;
            }
            logger.info(String.format("sql node optimize, original sql node toSql='%s'", sqlNode.toSql()));
            String alias = tableAliases.isEmpty() ? ((SqlIdentifier) from).getLast() : tableAliases.get(0);
            String subSelectAlias = subTableAliases.isEmpty() ? tableNode.cast(SqlIdentifier.class).getLast() : subTableAliases.get(0);
            this.rewriteWhereNode(node, where, inSubSqlSelect, alias);
            SqlBasicCall inCall = inSubSqlSelect.cast(SqlBasicCall.class);
            if (tableAliases.isEmpty() && from instanceof SqlIdentifier) {
                // Alias the bare outer table with its own name and qualify the IN's left operand with it.
                SqlParserPosition position = from.getPosition();
                from = new SqlBasicCall(position, SqlKind.AS, SqlOperators.of(SqlKind.AS), from, new SqlIdentifier(new SqlParserPosition(position.getLine(), position.getColumn() + 4, position.getEndLine(), position.getEndColumn() + 4), Collections.singletonList(alias)));
                SqlIdentifier identifier = inCall.getOperand(0).cast(SqlIdentifier.class);
                // Copy before prepending so the original identifier's name list is never modified.
                List<String> names = new ArrayList<String>(identifier.getNames());
                names.add(0, alias);
                inCall.setOperand(0, new SqlIdentifier(identifier.getPosition(), names));
            }
            // Alias the subquery with its table's name so its selected column can be referenced in the ON clause.
            SqlParserPosition position = inSubSqlSelect.getPosition();
            position = new SqlParserPosition(position.getLine(), position.getColumn() + 2, position.getEndLine(), position.getEndColumn() + 4);
            inCall.setOperand(1, new SqlBasicCall(position, SqlKind.AS, SqlOperators.of(SqlKind.AS), inCall.getOperand(1), new SqlIdentifier(position, Collections.singletonList(subSelectAlias))));
            // The join key is the selected column, or its output alias when the select item is "expr AS name".
            SqlIdentifier identifier;
            if (selectNode.getKind() == SqlKind.IDENTIFIER) {
                identifier = (SqlIdentifier) selectNode;
            } else if (selectNode.getKind() == SqlKind.AS) {
                SqlBasicCall call = (SqlBasicCall) selectNode;
                identifier = (SqlIdentifier) call.getOperand(1);
            } else {
                identifier = selectNode.cast(SqlIdentifier.class);
            }
            if (this.getRelateTables(identifier).isEmpty()) {
                List<String> names = new ArrayList<String>(identifier.getNames());
                names.add(0, subSelectAlias);
                identifier = new SqlIdentifier(identifier.getPosition(), names);
            }
            SqlParserPosition fromPosition = from.getPosition();
            SqlJoin sqlJoin = new SqlJoin(fromPosition, from, SqlJoinType.INNER.symbol(fromPosition), new SqlLiteral(new SqlParserPosition(position.getLine(), position.getEndColumn() + 2, position.getEndLine(), position.getEndColumn() + 4), "ON"), inCall.getOperand(1), new SqlBasicCall(fromPosition, SqlKind.EQUALS, SqlOperators.of(SqlKind.EQUALS), inCall.getOperand(0), identifier));
            node.setOperand(2, sqlJoin);
            logger.info(String.format("sql node optimize, optimize sql node toSql='%s'", sqlNode.toSql()));
        }
        return sqlNode;
    }

    /**
     * Removes the rewritten IN-subquery from the WHERE clause of {@code sqlSelect}.
     *
     * <p>If {@code where} is the IN predicate itself (by {@code equals}), WHERE becomes {@code null}.
     * Otherwise WHERE is replaced by the result of {@link #doRewrite}. Non-call WHERE nodes are left alone.
     */
    private void rewriteWhereNode(SqlSelect sqlSelect, SqlNode where, SqlNode inSubSqlSelect, String alias) {
        if (where instanceof SqlBasicCall) {
            SqlBasicCall sqlBasicCall = where.cast(SqlBasicCall.class);
            if (sqlBasicCall.equals(inSubSqlSelect)) {
                sqlSelect.setOperand(3, null);
                return;
            }
            sqlBasicCall = this.doRewrite(sqlBasicCall, inSubSqlSelect, alias);
            sqlSelect.setOperand(3, sqlBasicCall);
        }
    }

    /**
     * Recursively removes {@code inSubSqlSelect} from the operands of {@code sqlBasicCall}. It prefixes
     * unqualified left-operand identifiers of the remaining operand calls with {@code alias}.
     *
     * <p>NOTE(review): the return value is the first remaining operand, which assumes a binary connective
     * (e.g. AND) that has exactly one operand left after the removal.
     *
     * @return the first remaining operand, or {@code null} if any operand is not a {@link SqlBasicCall}
     */
    private SqlBasicCall doRewrite(SqlBasicCall sqlBasicCall, SqlNode inSubSqlSelect, String alias) {
        List<SqlNode> operandList = sqlBasicCall.getOperandList();
        Iterator<SqlNode> iterator = operandList.iterator();
        while (iterator.hasNext()) {
            SqlNode sqlNode = iterator.next();
            if (!(sqlNode instanceof SqlBasicCall)) {
                return null;
            }
            if (sqlNode.equals(inSubSqlSelect)) {
                iterator.remove();
                continue;
            }
            SqlBasicCall basicCall = sqlNode.cast(SqlBasicCall.class);
            SqlNode operand = basicCall.getOperand(0);
            SqlIdentifier identifier;
            if (operand.getKind() == SqlKind.IDENTIFIER) {
                identifier = (SqlIdentifier) operand;
            } else if (operand.getKind() == SqlKind.AS) {
                SqlBasicCall call = (SqlBasicCall) operand;
                identifier = (SqlIdentifier) call.getOperand(1);
            } else {
                identifier = operand.cast(SqlIdentifier.class);
            }
            if (this.getRelateTables(identifier).isEmpty()) {
                // Copy before prepending so the original identifier's name list is never modified.
                List<String> names = new ArrayList<String>(identifier.getNames());
                names.add(0, alias);
                identifier = new SqlIdentifier(identifier.getPosition(), names);
                basicCall.setOperand(0, identifier);
            }
            this.doRewrite(basicCall, inSubSqlSelect, alias);
        }
        return sqlBasicCall.getOperandList().get(0).cast(SqlBasicCall.class);
    }

    /**
     * Decides whether an {@code x IN (SELECT ...)} predicate can be rewritten into a join.
     *
     * <p>The subquery is optimized first. It qualifies only if all of the following hold:
     * <ul>
     *   <li>it has a FROM clause;</li>
     *   <li>it selects exactly one non-star item, which is either an identifier or {@code identifier AS name};</li>
     *   <li>every table qualifier used in that item and anywhere in its WHERE clause is one of its own FROM
     *       aliases, or no qualifier is used at all. Unqualified references cannot be resolved here and are
     *       assumed local.</li>
     * </ul>
     *
     * @return (subquery FROM, selected item, subquery WHERE), or {@code null} when the rewrite must be skipped
     */
    private ThreeTuple<SqlNode, SqlNode, SqlNode> getUnRelateSubQueryInfo(SqlNode inSubSqlSelect) {
        if (!(inSubSqlSelect instanceof SqlBasicCall)) {
            return null;
        }
        SqlBasicCall basicCall = inSubSqlSelect.cast(SqlBasicCall.class);
        SqlSelect select = basicCall.getOperand(1).cast(SqlSelect.class);
        select = this.optimize(select).cast(SqlSelect.class);
        SqlNode from = select.getFrom();
        if (from == null) {
            return null;
        }
        SqlNodeList sqlNodeList = select.getOperand(1).cast(SqlNodeList.class);
        if (sqlNodeList.size() != 1) {
            return null;
        }
        SqlNode node = sqlNodeList.get(0);
        SqlIdentifier identifier;
        if (node.getKind() == SqlKind.IDENTIFIER) {
            identifier = (SqlIdentifier) node;
        } else if (node.getKind() == SqlKind.AS) {
            SqlNode aliased = ((SqlBasicCall) node).getOperand(0);
            // "expr AS name" over a non-identifier expression is not supported; skip instead of failing the cast.
            if (!(aliased instanceof SqlIdentifier)) {
                return null;
            }
            identifier = (SqlIdentifier) aliased;
        } else {
            return null;
        }
        if (identifier.isStar()) {
            return null;
        }
        List<String> tableAliases = this.getTableAliases(from);
        List<String> relateTables = this.getRelateTables(identifier);
        SqlNode where = select.getWhere();
        // Look at every identifier in WHERE, not only the left operand of the top-level predicate. A correlated
        // reference nested under AND/OR, or on the right of a comparison (e.g. "b.aid = a.id"), must also mark
        // the subquery as correlated.
        this.collectQualifiers(where, relateTables);
        if (relateTables.isEmpty()) {
            return new ThreeTuple<SqlNode, SqlNode, SqlNode>(from, node, where);
        }
        if (!tableAliases.isEmpty()) {
            relateTables.removeAll(tableAliases);
            if (relateTables.isEmpty()) {
                return new ThreeTuple<SqlNode, SqlNode, SqlNode>(from, node, where);
            }
        }
        return null;
    }

    /**
     * Adds the table qualifier (see {@link #getRelateTables}) of every identifier reachable from {@code node}
     * to {@code qualifiers}.
     *
     * <p>The walk follows {@link SqlCall} operands and {@link SqlNodeList} items. This presumably also covers
     * nested SELECTs (TODO confirm SqlSelect extends SqlCall). Descending into them is deliberately
     * conservative: an inner alias that is not one of the subquery's own FROM aliases makes the caller skip
     * the rewrite.
     *
     * @param node       the expression to scan; {@code null} is ignored
     * @param qualifiers receives one entry per qualified identifier found
     */
    private void collectQualifiers(SqlNode node, List<String> qualifiers) {
        if (node == null) {
            return;
        }
        if (node instanceof SqlIdentifier) {
            qualifiers.addAll(this.getRelateTables((SqlIdentifier) node));
        } else if (node instanceof SqlNodeList) {
            SqlNodeList nodeList = node.cast(SqlNodeList.class);
            for (int i = 0; i < nodeList.size(); ++i) {
                this.collectQualifiers(nodeList.get(i), qualifiers);
            }
        } else if (node instanceof SqlCall) {
            for (SqlNode operand : ((SqlCall) node).getOperandList()) {
                this.collectQualifiers(operand, qualifiers);
            }
        }
    }

    /**
     * Returns the qualifier of a compound identifier: all name parts but the last, joined with {@code "."}
     * (e.g. {@code a.b.col -> "a.b"}). For a simple name the list is empty.
     *
     * <p>The identifier's own name list is not modified.
     *
     * @return a new mutable list holding zero or one qualifier
     */
    private List<String> getRelateTables(SqlIdentifier identifier) {
        List<String> names = identifier.getNames();
        ArrayList<String> relateTables = new ArrayList<String>(1);
        if (names.size() > 1) {
            relateTables.add(String.join(".", names.subList(0, names.size() - 1)));
        }
        return relateTables;
    }

    /**
     * Collects the explicit aliases of a FROM item.
     *
     * <ul>
     *   <li>For a JOIN, returns both aliases; both sides are expected to be {@code AS} calls.</li>
     *   <li>For an {@code AS} call, returns its alias.</li>
     *   <li>Otherwise (e.g. a bare table identifier), returns an empty list.</li>
     * </ul>
     */
    private List<String> getTableAliases(SqlNode from) {
        ArrayList<String> tableAliases = new ArrayList<String>();
        if (from.getKind() == SqlKind.JOIN) {
            SqlJoin sqlJoin = from.cast(SqlJoin.class);
            tableAliases.add(sqlJoin.getLeft().cast(SqlBasicCall.class).getOperand(1).cast(SqlIdentifier.class).getLast());
            tableAliases.add(sqlJoin.getRight().cast(SqlBasicCall.class).getOperand(1).cast(SqlIdentifier.class).getLast());
        } else if (from.getKind() == SqlKind.AS) {
            tableAliases.add(from.cast(SqlBasicCall.class).getOperand(1).cast(SqlIdentifier.class).getLast());
        }
        return tableAliases;
    }

    /** Simple immutable triple with value-based equality. */
    private static class ThreeTuple<T1, T2, T3> {
        private final T1 t1;
        private final T2 t2;
        private final T3 t3;

        public ThreeTuple(T1 t1, T2 t2, T3 t3) {
            this.t1 = t1;
            this.t2 = t2;
            this.t3 = t3;
        }

        public T1 getT1() {
            return this.t1;
        }

        public T2 getT2() {
            return this.t2;
        }

        public T3 getT3() {
            return this.t3;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            ThreeTuple<?, ?, ?> that = (ThreeTuple<?, ?, ?>) o;
            return Objects.equals(this.t1, that.t1) && Objects.equals(this.t2, that.t2) && Objects.equals(this.t3, that.t3);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.t1, this.t2, this.t3);
        }
    }

    /**
     * Rewrites {@code IN (v1, ..., vn)} value lists according to the size thresholds in {@link ServerOption}.
     * It returns the {@link InnerVirtualTableInfo} of any virtual table it creates.
     */
    private static class SqlNodeInVariableRewrite
    extends BaseASTVisitor<InnerVirtualTableInfo> {
        SqlNodeOptimizer sqlNodeOptimizer;

        public SqlNodeInVariableRewrite(SqlNodeOptimizer sqlNodeOptimizer) {
            this.sqlNodeOptimizer = sqlNodeOptimizer;
        }

        /** Visits every non-null operand of a call, aggregating non-null child results. */
        @Override
        public InnerVirtualTableInfo visitChildren(SqlNode node) {
            InnerVirtualTableInfo result = (InnerVirtualTableInfo)this.defaultResult();
            if (node instanceof SqlCall) {
                SqlCall call = (SqlCall)node;
                List<SqlNode> operandList = call.getOperandList();
                for (int i = 0; i < call.getOperandCount() && this.shouldVisitNextChild(node, result); ++i) {
                    InnerVirtualTableInfo next;
                    SqlNode nextNode = operandList.get(i);
                    if (nextNode == null || (next = nextNode.accept(this)) == null) continue;
                    result = this.aggregateResult(result, next);
                }
            }
            return result;
        }

        /**
         * For {@code IN} over a value list, handles the list by size:
         * <ul>
         *   <li>above the max size: throws {@link ErrorCode#OptimizeError1};</li>
         *   <li>above the join threshold: the list is replaced by a virtual-table subquery, and children are
         *       not visited;</li>
         *   <li>below the OR threshold: the list is expanded into OR-chained equalities.</li>
         * </ul>
         */
        @Override
        public InnerVirtualTableInfo visitSqlBasicCall(SqlBasicCall node) {
            SqlNode operand;
            if (node.getKind() == SqlKind.IN && (operand = node.getOperand(1)) instanceof SqlNodeList) {
                SqlNodeList nodeList = operand.cast(SqlNodeList.class);
                int size = nodeList.size();
                int opt2Or = Integer.parseInt(this.getStringFromConfig(ServerOption.SqlOptimizeInVariable2OrThreshold));
                int opt2Join = Integer.parseInt(this.getStringFromConfig(ServerOption.SqlOptimizeInVariable2JoinThreshold));
                int optMaxSize = Integer.parseInt(this.getStringFromConfig(ServerOption.SqlOptimizeInVariableMaxSize));
                if (size > optMaxSize) {
                    throw Exceptions.of((ErrorCode)ErrorCode.OptimizeError1, (Object[])new Object[]{optMaxSize, size});
                }
                if (size > opt2Join) {
                    return this.getInnerVirtualTable(node, nodeList);
                }
                if (size < opt2Or) {
                    this.optToOr(node, nodeList);
                }
            }
            return (InnerVirtualTableInfo)super.visitSqlBasicCall(node);
        }

        /**
         * Replaces the IN's value list with {@code SELECT $expr_inner_field$ FROM $$inner_virtual_table$$}.
         * It registers the new FROM identifier, with the original list, in the owning optimizer's map.
         * Node positions are synthetic offsets from the list's position.
         */
        private InnerVirtualTableInfo getInnerVirtualTable(SqlBasicCall node, SqlNodeList nodeList) {
            SqlParserPosition position = nodeList.getPosition();
            SqlNodeList selectList = new SqlNodeList(new SqlParserPosition(position.getLine(), position.getColumn() + 7 + "$expr_inner_field$".length(), position.getEndLine(), position.getEndColumn() + 7 + "$expr_inner_field$".length()), Collections.singletonList(new SqlIdentifier(new SqlParserPosition(position.getLine(), position.getColumn() + 7 + "$expr_inner_field$".length(), position.getEndLine(), position.getEndColumn() + 7 + "$expr_inner_field$".length()), Collections.singletonList("$expr_inner_field$"))));
            SqlIdentifier from = new SqlIdentifier(new SqlParserPosition(position.getLine(), position.getColumn() + 13 + "$expr_inner_field$".length() + "$$inner_virtual_table$$".length(), position.getEndLine(), position.getEndColumn() + 13 + "$expr_inner_field$".length() + "$$inner_virtual_table$$".length()), Collections.singletonList("$$inner_virtual_table$$"));
            SqlSelect sqlSelect = new SqlSelect(position, null, selectList, (SqlNode)from, null, null, null, null, null, null, null, null);
            node.setOperand(1, sqlSelect);
            InnerVirtualTableInfo innerVirtualTableInfo = new InnerVirtualTableInfo(nodeList);
            this.sqlNodeOptimizer.getSqlNodeInnerVirtualTableInfoMap().put(from, innerVirtualTableInfo);
            return innerVirtualTableInfo;
        }

        /**
         * Rewrites {@code x IN (a, b, c, ...)} in place into {@code x = a OR (x = b OR (x = c OR ...))}.
         * It replaces the node's operands and operator. A single-value list is left unchanged. The value
         * list is emptied as a side effect.
         */
        private void optToOr(SqlBasicCall node, SqlNodeList nodeList) {
            logger.info(String.format("sql node optimize, optimize sql node ='%s'", node.toSql()));
            SqlNode left = node.getOperand(0);
            SqlNode leftCondition = nodeList.get(0);
            SqlOperator equalsOp = SqlOperators.of(SqlKind.EQUALS);
            SqlOperator orOp = SqlOperators.of(SqlKind.OR);
            if (nodeList.size() > 1) {
                nodeList.remove(0);
                SqlBasicCall leftCall = new SqlBasicCall(left.getPosition(), SqlKind.EQUALS, equalsOp, left, leftCondition);
                SqlBasicCall mockBasicCall = new SqlBasicCall(left.getPosition(), SqlKind.OR, SqlOperators.of(SqlKind.OR), leftCall, leftCall);
                SqlBasicCall rightBasicCall = this.inConvert2Or(mockBasicCall, left, nodeList);
                if (rightBasicCall != null) {
                    node.setOperand(0, rightBasicCall.getOperand(0));
                    node.setOperand(1, rightBasicCall.getOperand(1));
                    node.setOperator(orOp);
                    logger.info(String.format("sql node optimize, optimize sql node toSql='%s'", node.toSql()));
                }
            }
        }

        /** Reads a server option from the current context config, falling back to the option's default. */
        private String getStringFromConfig(ServerOption serverOption) {
            String value = Contexts.get().getConfig(serverOption.key());
            return value != null ? value : serverOption.defaultValue();
        }

        /**
         * Recursively consumes {@code nodeList}, building a right-nested OR chain of {@code left = value}
         * calls. Operand 0 of {@code mockBasicCall} carries the previous equality between levels.
         *
         * @return the chain for the consumed values, or {@code null} if {@code nodeList} was empty
         */
        private SqlBasicCall inConvert2Or(SqlBasicCall mockBasicCall, SqlNode left, SqlNodeList nodeList) {
            Iterator<SqlNode> iterator = nodeList.iterator();
            if (iterator.hasNext()) {
                SqlNode next = iterator.next();
                iterator.remove();
                SqlBasicCall temp = mockBasicCall.getOperand(0).cast(SqlBasicCall.class);
                SqlBasicCall leftCall = new SqlBasicCall(left.getPosition(), SqlKind.EQUALS, SqlOperators.of(SqlKind.EQUALS), left, next);
                mockBasicCall.setOperand(0, leftCall);
                SqlBasicCall rightCall = this.inConvert2Or(mockBasicCall, left, nodeList);
                if (rightCall == null) {
                    mockBasicCall.setOperand(0, temp);
                    mockBasicCall.setOperand(1, leftCall);
                    return mockBasicCall;
                }
                return new SqlBasicCall(left.getPosition(), SqlKind.OR, SqlOperators.of(SqlKind.OR), temp, rightCall);
            }
            return null;
        }
    }

    /**
     * Finds {@code IN (SELECT ...)} predicates.
     *
     * <p>The result is {@code (true, matches)} when at least one is found; otherwise it is the visitor's
     * default/aggregated result.
     */
    private static class SqlNodeInSubPredicate
    extends BaseASTVisitor<Pair<Boolean, List<SqlNode>>> {
        private SqlNodeInSubPredicate() {
        }

        @Override
        public Pair<Boolean, List<SqlNode>> visitChildren(SqlNode node) {
            Pair<Boolean, List<SqlNode>> result = (Pair<Boolean, List<SqlNode>>)this.defaultResult();
            if (node instanceof SqlCall) {
                SqlCall call = (SqlCall)node;
                List<SqlNode> operandList = call.getOperandList();
                List<SqlNode> inSqlSelects = new ArrayList<SqlNode>(operandList.size());
                for (int i = 0; i < call.getOperandCount() && this.shouldVisitNextChild(node, result); ++i) {
                    Pair<Boolean, List<SqlNode>> next;
                    SqlNode nextNode = operandList.get(i);
                    if (nextNode == null || (next = nextNode.accept(this)) == null) continue;
                    inSqlSelects.addAll(next.getValue());
                    result = this.aggregateResult(result, next);
                }
                if (inSqlSelects.isEmpty()) {
                    return result;
                }
                return Pair.of(true, inSqlSelects);
            }
            return result;
        }

        @Override
        public Pair<Boolean, List<SqlNode>> visitSqlBasicCall(SqlBasicCall node) {
            SqlNode operand;
            if (node.getKind() == SqlKind.IN && (operand = node.getOperand(1)) instanceof SqlSelect) {
                return Pair.of(true, Collections.singletonList(node));
            }
            return (Pair)super.visitSqlBasicCall(node);
        }
    }
}

