/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.flydb.core.sql.operator;

import kd.bos.flydb.common.exception.ErrorCode;
import kd.bos.flydb.common.exception.Exceptions;
import kd.bos.flydb.core.sql.operator.OperandTypeInferences;
import kd.bos.flydb.core.sql.operator.ReturnTypeInference;
import kd.bos.flydb.core.sql.operator.SqlOperatorImpl;
import kd.bos.flydb.core.sql.tree.SqlBasicCall;
import kd.bos.flydb.core.sql.tree.SqlCall;
import kd.bos.flydb.core.sql.tree.SqlCaseWhenType;
import kd.bos.flydb.core.sql.tree.SqlKind;
import kd.bos.flydb.core.sql.tree.SqlLiteral;
import kd.bos.flydb.core.sql.tree.SqlNode;
import kd.bos.flydb.core.sql.tree.SqlNodeList;
import kd.bos.flydb.core.sql.type.DataType;
import kd.bos.flydb.core.sql.unparse.SqlWriter;
import kd.bos.flydb.core.sql.util.Pair;
import kd.bos.flydb.core.sql.validate.SqlValidator;
import kd.bos.flydb.core.sql.validate.SqlValidatorScope;

/**
 * Operator for SQL {@code CASE} expressions.
 *
 * <p>Operand 0 of a CASE call is a {@link SqlLiteral} naming a {@link SqlCaseWhenType}; it
 * selects one of two operand layouts:
 * <ul>
 *   <li>searched CASE: {@code [type, WHEN list, THEN list, ELSE]}</li>
 *   <li>simple CASE: {@code [type, value, WHEN list, THEN list, ELSE]}</li>
 * </ul>
 * The ELSE operand may be {@code null} when the expression has no ELSE branch.
 */
public class SqlCaseOperator
extends SqlOperatorImpl {
    /**
     * Registers the operator under the CASE kind with CASE-specific return type inference.
     * NOTE(review): {@code 200} and {@code true} are passed straight to {@link SqlOperatorImpl};
     * presumably precedence and associativity — confirm against that constructor.
     */
    public SqlCaseOperator() {
        super(SqlKind.CASE.getName(), SqlKind.CASE, 200, true, OperandTypeInferences.ANY, new SqlCaseWhenReturnTypeInference());
    }

    /**
     * Validates that the branches of the CASE expression have consistent types.
     *
     * <p>Every THEN result, and the ELSE result when present, must infer to the same
     * {@link DataType} as the first THEN result. For a simple CASE, every WHEN operand must
     * additionally infer to the same type as the case value (operand 1).
     *
     * @param sqlValidator validator used to infer operand types
     * @param scope        scope in which the operands are resolved
     * @param sqlNode      the CASE call; must be a {@link SqlBasicCall}
     */
    @Override
    public void checkOperandType(SqlValidator sqlValidator, SqlValidatorScope scope, SqlNode sqlNode) {
        SqlBasicCall caseWhen = sqlNode.cast(SqlBasicCall.class);
        SqlCaseWhenType type = caseWhenTypeOf(caseWhen);

        SqlNodeList thens = caseWhen.getOperand(thenOperandIndex(type)).cast(SqlNodeList.class);
        // The first THEN result fixes the type every other result branch must match.
        DataType resultType = sqlValidator.inferDataType(thens.get(0), scope);
        for (int i = 1; i < thens.size(); ++i) {
            requireSameType(resultType, sqlValidator.inferDataType(thens.get(i), scope));
        }

        SqlNode elseNode = caseWhen.getOperand(elseOperandIndex(type));
        if (elseNode != null) {
            requireSameType(resultType, sqlValidator.inferDataType(elseNode, scope));
        }

        if (type == SqlCaseWhenType.SimpleCase) {
            // In a simple CASE each WHEN operand is compared against the case value,
            // so it must share the value's type.
            DataType valueType = sqlValidator.inferDataType(caseWhen.getOperand(1), scope);
            SqlNodeList whens = caseWhen.getOperand(whenOperandIndex(type)).cast(SqlNodeList.class);
            for (SqlNode when : whens) {
                requireSameType(valueType, sqlValidator.inferDataType(when, scope));
            }
        }
    }

    /**
     * Writes the expression as {@code CASE [value] WHEN w THEN t ... [ELSE e] END}.
     *
     * <p>The case value is written only for a simple CASE; the ELSE clause only when the ELSE
     * operand is non-null. The precedence arguments are not used.
     */
    @Override
    public void unParse(SqlWriter writer, SqlCall sqlCall, int leftPrecedence, int rightPrecedence) {
        SqlBasicCall sqlCase = sqlCall.cast(SqlBasicCall.class);
        // Decode the CASE flavour exactly like the validator does, so unparsing can never
        // disagree with validation about the operand layout.
        SqlCaseWhenType caseWhenType = caseWhenTypeOf(sqlCase);
        SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.CASE, "CASE", "END");
        SqlNodeList whenList = sqlCase.getOperand(whenOperandIndex(caseWhenType)).cast(SqlNodeList.class);
        SqlNodeList thenList = sqlCase.getOperand(thenOperandIndex(caseWhenType)).cast(SqlNodeList.class);
        assert whenList.size() == thenList.size();
        if (caseWhenType == SqlCaseWhenType.SimpleCase) {
            sqlCase.getOperand(1).unParse(writer, 0, 0);
        }
        for (Pair<SqlNode, SqlNode> pair : Pair.zip(whenList, thenList)) {
            writer.sep("WHEN");
            pair.getType().unParse(writer, 0, 0);
            writer.sep("THEN");
            pair.getValue().unParse(writer, 0, 0);
        }
        SqlNode elseExpr = sqlCase.getOperand(elseOperandIndex(caseWhenType));
        if (elseExpr != null) {
            writer.sep("ELSE");
            elseExpr.unParse(writer, 0, 0);
        }
        writer.endList(frame);
    }

    /**
     * Decodes the CASE flavour from the literal stored in operand 0 of {@code call}.
     * Throws {@link IllegalArgumentException} (from {@code Enum.valueOf}) if the literal's
     * text does not name a {@link SqlCaseWhenType} constant.
     */
    private static SqlCaseWhenType caseWhenTypeOf(SqlBasicCall call) {
        SqlLiteral literal = call.getOperand(0).cast(SqlLiteral.class);
        return SqlCaseWhenType.valueOf(literal.getValue().toString());
    }

    /** Index of the WHEN list operand: 1 for a searched CASE, 2 otherwise. */
    private static int whenOperandIndex(SqlCaseWhenType type) {
        return type == SqlCaseWhenType.SearchedCase ? 1 : 2;
    }

    /** Index of the THEN list operand: 2 for a searched CASE, 3 otherwise. */
    private static int thenOperandIndex(SqlCaseWhenType type) {
        return type == SqlCaseWhenType.SearchedCase ? 2 : 3;
    }

    /** Index of the (possibly null) ELSE operand: 3 for a searched CASE, 4 otherwise. */
    private static int elseOperandIndex(SqlCaseWhenType type) {
        return type == SqlCaseWhenType.SearchedCase ? 3 : 4;
    }

    /** Throws {@code CaseWhenOperatorTypeNotMatched} unless {@code expected.equals(actual)}. */
    private static void requireSameType(DataType expected, DataType actual) {
        if (!expected.equals(actual)) {
            throw Exceptions.of(ErrorCode.CaseWhenOperatorTypeNotMatched, new Object[0]);
        }
    }

    /**
     * Infers the type of a CASE expression as the inferred type of its first THEN result.
     * Consistency of the remaining branches is enforced separately by
     * {@link SqlCaseOperator#checkOperandType}.
     */
    public static class SqlCaseWhenReturnTypeInference
    implements ReturnTypeInference {
        @Override
        public DataType inferReturnType(SqlValidator sqlValidator, SqlValidatorScope scope, SqlNode sqlNode) {
            SqlBasicCall caseWhen = sqlNode.cast(SqlBasicCall.class);
            SqlCaseWhenType type = caseWhenTypeOf(caseWhen);
            SqlNodeList thens = caseWhen.getOperand(thenOperandIndex(type)).cast(SqlNodeList.class);
            return sqlValidator.inferDataType(thens.get(0), scope);
        }
    }
}

