/*
 * Decompiled with CFR 0.152.
 */
package kd.bos.xdb.sharding.sql.dml;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import kd.bos.xdb.engine.ShardingContext;
import kd.bos.xdb.sharding.ShardingFieldValue;
import kd.bos.xdb.sharding.ShardingGroupTable;
import kd.bos.xdb.sharding.sql.condition.ConditionExprList;
import kd.bos.xdb.sharding.sql.dml.DMLShardingSQL;
import kd.bos.xdb.sharding.sql.dml.update.PrepareUpdate;
import kd.bos.xdb.sharding.sql.parser.StatementInfo;
import kd.bos.xdb.tablemanager.TableName;

/**
 * {@link DMLShardingSQL} specialisation for statements whose parsed form is a
 * {@link SQLUpdateStatement}.
 *
 * <p>Construction triggers a {@link PrepareUpdate} pass unless the current
 * {@link ShardingContext} already reports it as prepared.
 */
public class UpdateShardingSQL extends DMLShardingSQL {

    /**
     * Builds the sharding SQL for the given statement and, when the current
     * {@link ShardingContext} has not yet been marked as prepared, runs
     * {@link PrepareUpdate#prepare()} for this instance.
     *
     * @param stmtInfo parsed statement information, forwarded to the superclass
     */
    public UpdateShardingSQL(StatementInfo stmtInfo) {
        super(stmtInfo);
        ShardingContext context = ShardingContext.get();
        if (context.isUpdateShardingFieldAndIndexPrepared()) {
            return;
        }
        try {
            PrepareUpdate preparer = new PrepareUpdate(this);
            preparer.prepare();
        }
        finally {
            // The flag is raised whether or not prepare() completes normally,
            // so a failed preparation is not attempted again on this context.
            context.setUpdateShardingFieldAndIndexPrepared(true);
        }
    }

    /**
     * @return always {@code false} for this implementation
     */
    @Override
    protected boolean isFullShardingValueRequired() {
        return false;
    }

    /**
     * Collects the condition expressions of the UPDATE statement.
     *
     * @return a list holding the statement's WHERE expression, or an empty list
     *         when the statement has no WHERE clause
     */
    @Override
    protected ConditionExprList collectConditionExprs() {
        SQLUpdateStatement update = (SQLUpdateStatement)this.getStatementInfo().getSQLStatement();
        ConditionExprList conditions = new ConditionExprList();
        SQLExpr whereExpr = update.getWhere();
        if (whereExpr == null) {
            return conditions;
        }
        conditions.add(whereExpr);
        return conditions;
    }

    /**
     * Extends {@code groups} with one extra group per index found in
     * {@link ShardingContext#getUpdateShardingFieldToNewTables()}.
     *
     * <p>Each added group is modelled on the first element of {@code groups}:
     * it receives that element's field values and sharding hint context, and
     * its table is resolved from the first element's table name plus the index.
     * When the context returns {@code null} the input array is returned as is.
     *
     * @param groups groups resolved so far; the first element is read as the
     *               template whenever the context supplies a non-null index set
     * @return the original array, or a new array with the extra groups appended
     */
    @Override
    protected ShardingGroupTable[] appendGroups(ShardingGroupTable[] groups) {
        Set<Long> targetIndexes = ShardingContext.get().getUpdateShardingFieldToNewTables();
        if (targetIndexes == null) {
            return groups;
        }

        List<ShardingGroupTable> merged = new ArrayList<ShardingGroupTable>(Arrays.asList(groups));

        // The first group acts as the template for every appended group.
        ShardingGroupTable template = groups[0];
        List<ShardingFieldValue> templateValues = template.getFieldValues();
        TableName tableName = TableName.of(template.getShardingTable());

        for (long shardIndex : targetIndexes) {
            ShardingGroupTable extra = new ShardingGroupTable(tableName.getShardingTable(shardIndex));
            extra.addShardingFieldValues(templateValues);
            extra.setShardingHintContext(template.getShardingHintContext());
            merged.add(extra);
        }
        return merged.toArray(new ShardingGroupTable[0]);
    }
}

