/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.flydb.core.sql.operator;

import kd.bos.flydb.common.exception.ErrorCode;
import kd.bos.flydb.common.exception.Exceptions;
import kd.bos.flydb.core.sql.operator.OperandTypeInferences;
import kd.bos.flydb.core.sql.operator.ReturnTypeInference;
import kd.bos.flydb.core.sql.operator.SqlFunctionOperatorImpl;
import kd.bos.flydb.core.sql.tree.SqlBasicCall;
import kd.bos.flydb.core.sql.tree.SqlKind;
import kd.bos.flydb.core.sql.tree.SqlNode;
import kd.bos.flydb.core.sql.type.DataType;
import kd.bos.flydb.core.sql.type.SqlTypeCategory;
import kd.bos.flydb.core.sql.validate.SqlValidator;
import kd.bos.flydb.core.sql.validate.SqlValidatorScope;

public class SqlAggregateFunctionOperator
extends SqlFunctionOperatorImpl {
    public SqlAggregateFunctionOperator(String name, SqlKind sqlKind) {
        super(name, sqlKind, OperandTypeInferences.FIRST_KNOWN, new AggregateReturnTypeInference(sqlKind));
    }

    @Override
    public void checkOperandCount(SqlValidator sqlValidator, SqlValidatorScope scope, SqlNode sqlNode) {
        SqlBasicCall call = sqlNode.cast(SqlBasicCall.class);
        if (call.getOperandCount() != 1) {
            throw Exceptions.of((ErrorCode)ErrorCode.AggregateFunctionRequireArgument, (Object[])new Object[]{this.name(), 1});
        }
        super.checkOperandCount(sqlValidator, scope, sqlNode);
    }

    @Override
    public void checkOperandType(SqlValidator sqlValidator, SqlValidatorScope scope, SqlNode sqlNode) {
        if (SqlKind.FUNC_COUNT != this.getKind()) {
            SqlBasicCall call = sqlNode.cast(SqlBasicCall.class);
            for (int i = 0; i < call.getOperandList().size(); ++i) {
                DataType type = sqlValidator.inferDataType(call.getOperand(i), scope);
                if (type.getCategory() == SqlTypeCategory.NUMBER) continue;
                throw Exceptions.of((ErrorCode)ErrorCode.AggregateFunctionRequireArgumentType, (Object[])new Object[]{this.name()});
            }
        }
        super.checkOperandType(sqlValidator, scope, sqlNode);
    }

    private static class AggregateReturnTypeInference
    implements ReturnTypeInference {
        private final SqlKind sqlKind;

        public AggregateReturnTypeInference(SqlKind sqlKind) {
            this.sqlKind = sqlKind;
        }

        @Override
        public DataType inferReturnType(SqlValidator sqlValidator, SqlValidatorScope scope, SqlNode sqlNode) {
            if (this.sqlKind == SqlKind.FUNC_COUNT) {
                return sqlValidator.getTypeFactory().buildLong();
            }
            if (this.sqlKind == SqlKind.FUNC_AVG) {
                return sqlValidator.getTypeFactory().getMaxPrecisionDecimal();
            }
            SqlBasicCall call = sqlNode.cast(SqlBasicCall.class);
            DataType returnType = sqlValidator.inferDataType(call.getOperand(0), scope);
            if (returnType == null) {
                throw Exceptions.of((ErrorCode)ErrorCode.AggregateFunctionRequireReturnType, (Object[])new Object[0]);
            }
            return returnType;
        }
    }
}

