/*
 * Decompiled with CFR 0.152.
 */
package cn.win_trust_erpc.bouncycastle.pqc.crypto.sphincsplus;

import cn.win_trust_erpc.bouncycastle.pqc.crypto.sphincsplus.ADRS;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.sphincsplus.NodeEntry;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.sphincsplus.SIG_FORS;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.sphincsplus.SPHINCSPlusEngine;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.sphincsplus.WotsPlus;
import cn.win_trust_erpc.bouncycastle.util.Arrays;
import java.util.LinkedList;

/**
 * FORS part of this package's SPHINCS+ implementation: derives a FORS public key from the seeds,
 * signs a message digest, and recomputes the public key from a signature.
 *
 * <p>All hashing is delegated to the supplied {@link SPHINCSPlusEngine}, whose {@code K},
 * {@code A} and {@code T} fields provide the number of trees, the tree height and the
 * per-tree index stride used below.
 */
class Fors {
    // Built in the constructor; not referenced by any method in this class.
    private final WotsPlus wots;
    SPHINCSPlusEngine engine;

    /**
     * Creates a FORS helper bound to {@code engine}.
     *
     * @param engine source of the hash functions and of the {@code K}/{@code A}/{@code T} parameters
     */
    public Fors(SPHINCSPlusEngine engine) {
        this.engine = engine;
        this.wots = new WotsPlus(engine);
    }

    /**
     * Computes the FORS public key: the {@code T_l} compression of the {@code K} tree roots.
     *
     * @param skSeed secret seed handed to {@link #treehash}
     * @param pkSeed public seed handed to the engine's hash calls
     * @param adrs   address used for the tree computations; a copy of it (type set to 4, same
     *               key-pair address) is used for the final compression
     * @return the value returned by {@code engine.T_l} over the concatenated roots
     */
    byte[] pkGen(byte[] skSeed, byte[] pkSeed, ADRS adrs) {
        ADRS pkAddress = new ADRS(adrs);

        final int treeCount = engine.K;
        byte[][] roots = new byte[treeCount][];
        for (int tree = 0; tree < treeCount; tree++) {
            // Tree number `tree` starts at leaf index tree * T and has height A.
            roots[tree] = treehash(skSeed, tree * engine.T, engine.A, pkSeed, adrs);
        }

        // NOTE(review): 4 is presumably the "FORS roots" address type - confirm against ADRS.
        pkAddress.setType(4);
        pkAddress.setKeyPairAddress(adrs.getKeyPairAddress());
        return engine.T_l(pkSeed, pkAddress, Arrays.concatenate(roots));
    }

    /**
     * Computes the root of the binary hash tree of height {@code z} whose leftmost leaf has
     * index {@code s}. Leaves are {@code F(PRF(skSeed, adrs))}; inner nodes are combined with
     * {@code H}. Works on a private copy of {@code adrsParam}.
     *
     * @param skSeed    secret seed the leaf secrets are derived from
     * @param s         index of the leftmost leaf; must be a multiple of {@code 1 << z}
     * @param z         height of the subtree to compute
     * @param pkSeed    public seed handed to the engine's hash calls
     * @param adrsParam address template; it is copied, the caller's instance is not modified here
     * @return the subtree root, or {@code null} when {@code s} is not aligned to {@code 1 << z}
     */
    byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam) {
        ADRS adrs = new ADRS(adrsParam);
        final int leafCount = 1 << z;
        if (s % leafCount != 0) {
            return null;
        }

        // Head of the list is the top of the stack.
        LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
        for (int offset = 0; offset < leafCount; offset++) {
            final int leafIndex = s + offset;

            // Leaf: secret value derived at height 0, then hashed with F.
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(leafIndex);
            byte[] secret = engine.PRF(skSeed, adrs);
            byte[] node = engine.F(pkSeed, adrs, secret);

            adrs.setTreeHeight(1);
            adrs.setTreeIndex(leafIndex);

            // Merge with the stack top for as long as it sits at the height tracked in adrs.
            while (!stack.isEmpty() && stack.peekFirst().nodeHeight == adrs.getTreeHeight()) {
                adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
                NodeEntry left = stack.pop();
                node = engine.H(pkSeed, adrs, left.nodeValue, node);
                // The merged node lives one level higher.
                adrs.setTreeHeight(adrs.getTreeHeight() + 1);
            }
            stack.push(new NodeEntry(node, adrs.getTreeHeight()));
        }
        return stack.get(0).nodeValue;
    }

    /**
     * Produces one {@link SIG_FORS} per tree: the secret value selected by the digest plus the
     * authentication path of {@code A} sibling subtree roots.
     *
     * <p>Calls {@code setTreeHeight} and {@code setTreeIndex} on the caller's {@code adrs}.
     *
     * @param md     message digest the per-tree indices are extracted from
     * @param skSeed secret seed
     * @param pkSeed public seed
     * @param adrs   address to sign under; modified in place
     * @return array of {@code K} signature elements, in tree order
     */
    public SIG_FORS[] sign(byte[] md, byte[] skSeed, byte[] pkSeed, ADRS adrs) {
        int[] indices = message_to_idxs(md, engine.K, engine.A);
        SIG_FORS[] signature = new SIG_FORS[engine.K];
        final int stride = engine.T;

        for (int tree = 0; tree < engine.K; tree++) {
            final int leaf = indices[tree];
            final int treeBase = tree * stride;

            // Secret value of the selected leaf.
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(treeBase + leaf);
            byte[] secret = engine.PRF(skSeed, adrs);

            // One sibling subtree root per level of the tree.
            byte[][] path = new byte[engine.A][];
            for (int level = 0; level < engine.A; level++) {
                final int width = 1 << level;
                final int sibling = (leaf / width) ^ 1;
                path[level] = treehash(skSeed, treeBase + sibling * width, level, pkSeed, adrs);
            }
            signature[tree] = new SIG_FORS(secret, path);
        }
        return signature;
    }

    /**
     * Recomputes the FORS public key implied by a signature and a message digest.
     *
     * <p>Calls {@code setTreeHeight} and {@code setTreeIndex} on the caller's {@code adrs}.
     *
     * @param sig_fors signature elements, one per tree
     * @param message  message digest the per-tree indices are extracted from
     * @param pkSeed   public seed
     * @param adrs     address the signature was produced under; modified in place
     * @return the value returned by {@code engine.T_l} over the concatenated recomputed roots
     */
    public byte[] pkFromSig(SIG_FORS[] sig_fors, byte[] message, byte[] pkSeed, ADRS adrs) {
        byte[][] roots = new byte[engine.K][];
        final int stride = engine.T;
        int[] indices = message_to_idxs(message, engine.K, engine.A);

        for (int tree = 0; tree < engine.K; tree++) {
            final int leaf = indices[tree];
            final int leafIndex = tree * stride + leaf;

            // Leaf node from the revealed secret value.
            byte[] secret = sig_fors[tree].getSK();
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(leafIndex);
            byte[] current = engine.F(pkSeed, adrs, secret);

            byte[][] path = sig_fors[tree].getAuthPath();
            adrs.setTreeIndex(leafIndex);

            // Walk up the tree, combining with the path element on the correct side.
            for (int level = 0; level < engine.A; level++) {
                adrs.setTreeHeight(level + 1);
                final boolean currentIsLeft = leaf / (1 << level) % 2 == 0;
                final byte[] parent;
                if (currentIsLeft) {
                    adrs.setTreeIndex(adrs.getTreeIndex() / 2);
                    parent = engine.H(pkSeed, adrs, current, path[level]);
                } else {
                    adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
                    parent = engine.H(pkSeed, adrs, path[level], current);
                }
                current = parent;
            }
            roots[tree] = current;
        }

        ADRS pkAddress = new ADRS(adrs);
        // NOTE(review): 4 is presumably the "FORS roots" address type - confirm against ADRS.
        pkAddress.setType(4);
        pkAddress.setKeyPairAddress(adrs.getKeyPairAddress());
        return engine.T_l(pkSeed, pkAddress, Arrays.concatenate(roots));
    }

    /**
     * Splits {@code msg} into {@code fors_trees} integers of {@code fors_height} bits each.
     *
     * <p>Bits are consumed consecutively across the whole array, starting with the
     * least-significant bit of each byte; the j-th bit consumed for an index becomes bit
     * {@code j} of that index.
     *
     * @param msg         source bytes; must hold at least {@code fors_trees * fors_height} bits
     * @param fors_trees  number of indices to extract
     * @param fors_height number of bits per index
     * @return the extracted indices, in order
     */
    static int[] message_to_idxs(byte[] msg, int fors_trees, int fors_height) {
        int[] idxs = new int[fors_trees];
        int bitPos = 0;
        for (int tree = 0; tree < fors_trees; tree++) {
            int value = 0;
            for (int bit = 0; bit < fors_height; bit++) {
                int sourceBit = (msg[bitPos >> 3] >> (bitPos & 7)) & 1;
                value ^= sourceBit << bit;
                bitPos++;
            }
            idxs[tree] = value;
        }
        return idxs;
    }
}

