/*
 * Decompiled with CFR 0.152.
 */
package cn.win_trust_erpc.bouncycastle.pqc.crypto.xmss;

import cn.win_trust_erpc.bouncycastle.pqc.crypto.xmss.KeyedHashFunctions;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.xmss.OTSHashAddress;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.xmss.WOTSPlusParameters;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.xmss.WOTSPlusPrivateKeyParameters;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.xmss.WOTSPlusPublicKeyParameters;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.xmss.WOTSPlusSignature;
import cn.win_trust_erpc.bouncycastle.pqc.crypto.xmss.XMSSUtil;
import cn.win_trust_erpc.bouncycastle.util.Arrays;
import java.util.ArrayList;
import java.util.List;

/**
 * One-time signature helper built on hash chains.
 *
 * <p>An instance is configured with a {@link WOTSPlusParameters} object and holds two
 * byte-array seeds (a secret key seed and a public seed), each as long as the tree
 * digest size reported by the parameters. The secret seed is expanded into one chain
 * start value per chain index; the public seed keys the PRF that produces the per-step
 * hash keys and bitmasks.
 *
 * <p>NOTE(review): the structure matches the WOTS+ scheme of RFC 8391 (base-w encoding,
 * left-shifted checksum, keyed/masked chaining) - confirm against the RFC before relying
 * on interoperability.
 */
final class WOTSPlus {
    /** Parameter set (digest, digest size, Winternitz parameter, len/len1/len2). */
    private final WOTSPlusParameters params;
    /** Keyed hash functions (PRF and F) built from the parameter set's tree digest. */
    private final KeyedHashFunctions khf;
    /** Seed from which the per-chain secret start values are derived. */
    private byte[] secretKeySeed;
    /** Seed used to derive the per-step keys and bitmasks of the chaining function. */
    private byte[] publicSeed;

    /**
     * Creates an instance whose seeds are initialised to all-zero arrays of digest size.
     *
     * @param params parameter set to use
     * @throws NullPointerException if {@code params} is {@code null}
     */
    WOTSPlus(WOTSPlusParameters params) {
        if (params == null) {
            throw new NullPointerException("params == null");
        }
        this.params = params;
        int n = params.getTreeDigestSize();
        this.khf = new KeyedHashFunctions(params.getTreeDigest(), n);
        this.secretKeySeed = new byte[n];
        this.publicSeed = new byte[n];
    }

    /**
     * Replaces both seeds.
     *
     * <p>The arrays are copied, so later modification of the caller's arrays does not
     * alter the keys held by this instance (mirrors the cloning done by the getters).
     *
     * @param secretKeySeed new secret key seed; length must equal the tree digest size
     * @param publicSeed    new public seed; length must equal the tree digest size
     * @throws NullPointerException     if either argument is {@code null}
     * @throws IllegalArgumentException if either argument has the wrong length
     */
    void importKeys(byte[] secretKeySeed, byte[] publicSeed) {
        if (secretKeySeed == null) {
            throw new NullPointerException("secretKeySeed == null");
        }
        if (secretKeySeed.length != this.params.getTreeDigestSize()) {
            throw new IllegalArgumentException("size of secretKeySeed needs to be equal to size of digest");
        }
        if (publicSeed == null) {
            throw new NullPointerException("publicSeed == null");
        }
        if (publicSeed.length != this.params.getTreeDigestSize()) {
            throw new IllegalArgumentException("size of publicSeed needs to be equal to size of digest");
        }
        this.secretKeySeed = Arrays.clone(secretKeySeed);
        this.publicSeed = Arrays.clone(publicSeed);
    }

    /**
     * Signs a message digest.
     *
     * <p>The digest is encoded as base-w digits followed by checksum digits; chain
     * {@code i} is then advanced from its secret start value by the {@code i}-th digit.
     *
     * @param messageDigest  digest to sign; length must equal the tree digest size
     * @param otsHashAddress address supplying the layer, tree, OTS, hash and key-and-mask
     *                       fields; the chain address is overwritten per chain
     * @return signature holding one chain value per chain index
     * @throws NullPointerException     if an argument is {@code null}
     * @throws IllegalArgumentException if {@code messageDigest} has the wrong length
     */
    WOTSPlusSignature sign(byte[] messageDigest, OTSHashAddress otsHashAddress) {
        if (messageDigest == null) {
            throw new NullPointerException("messageDigest == null");
        }
        if (messageDigest.length != this.params.getTreeDigestSize()) {
            throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
        }
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        List<Integer> chainLengths = this.toBaseWWithChecksum(messageDigest);
        int len = this.params.getLen();
        byte[][] signature = new byte[len][];
        for (int i = 0; i < len; i++) {
            OTSHashAddress chainAddress = withChainAddress(otsHashAddress, i);
            signature[i] = this.chain(this.expandSecretKeySeed(i), 0, chainLengths.get(i), chainAddress);
        }
        return new WOTSPlusSignature(this.params, signature);
    }

    /**
     * Recomputes the public key implied by a signature over a message digest.
     *
     * <p>Each signature element is advanced along its chain by the steps that remain
     * between its digit and {@code w - 1}.
     *
     * @param messageDigest  digest that was signed; length must equal the tree digest size
     * @param signature      signature to evaluate
     * @param otsHashAddress address supplying the layer, tree, OTS, hash and key-and-mask
     *                       fields; the chain address is overwritten per chain
     * @return public key holding one chain end value per chain index
     * @throws NullPointerException     if an argument is {@code null}
     * @throws IllegalArgumentException if {@code messageDigest} has the wrong length
     */
    WOTSPlusPublicKeyParameters getPublicKeyFromSignature(byte[] messageDigest, WOTSPlusSignature signature, OTSHashAddress otsHashAddress) {
        if (messageDigest == null) {
            throw new NullPointerException("messageDigest == null");
        }
        if (messageDigest.length != this.params.getTreeDigestSize()) {
            throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
        }
        if (signature == null) {
            throw new NullPointerException("signature == null");
        }
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        List<Integer> chainLengths = this.toBaseWWithChecksum(messageDigest);
        int len = this.params.getLen();
        int maxIndex = this.params.getWinternitzParameter() - 1;
        // Fetch the signature elements once instead of on every loop iteration.
        byte[][] signatureChains = signature.toByteArray();
        byte[][] publicKey = new byte[len][];
        for (int i = 0; i < len; i++) {
            OTSHashAddress chainAddress = withChainAddress(otsHashAddress, i);
            int alreadyDone = chainLengths.get(i);
            publicKey[i] = this.chain(signatureChains[i], alreadyDone, maxIndex - alreadyDone, chainAddress);
        }
        return new WOTSPlusPublicKeyParameters(this.params, publicKey);
    }

    /**
     * Encodes a digest as {@code len1} base-w digits and appends {@code len2} base-w
     * digits of the checksum {@code sum(w - 1 - digit)}.
     *
     * <p>Before encoding, the checksum is shifted left by
     * {@code 8 - (len2 * log2(w)) % 8} bits and serialised big-endian into
     * {@code ceil(len2 * log2(w) / 8)} bytes.
     */
    private List<Integer> toBaseWWithChecksum(byte[] messageDigest) {
        int w = this.params.getWinternitzParameter();
        int len1 = this.params.getLen1();
        int len2 = this.params.getLen2();
        List<Integer> digits = this.convertToBaseW(messageDigest, w, len1);
        int checksum = 0;
        for (int i = 0; i < len1; i++) {
            checksum += w - 1 - digits.get(i);
        }
        int checksumBits = len2 * XMSSUtil.log2(w);
        int checksumBytes = (checksumBits + 7) / 8;
        checksum <<= 8 - checksumBits % 8;
        digits.addAll(this.convertToBaseW(XMSSUtil.toBytesBigEndian(checksum, checksumBytes), w, len2));
        return digits;
    }

    /**
     * Builds a new address that copies layer, tree and OTS address from {@code base}
     * and sets the chain address, hash address and key-and-mask to the given values.
     */
    private static OTSHashAddress deriveAddress(OTSHashAddress base, int chainAddress, int hashAddress, int keyAndMask) {
        // The casts mirror the builder's declared return types; do not remove them
        // without checking the builder hierarchy.
        OTSHashAddress.Builder builder = (OTSHashAddress.Builder)new OTSHashAddress.Builder().withLayerAddress(base.getLayerAddress());
        builder = (OTSHashAddress.Builder)builder.withTreeAddress(base.getTreeAddress());
        builder = (OTSHashAddress.Builder)builder.withOTSAddress(base.getOTSAddress()).withChainAddress(chainAddress).withHashAddress(hashAddress).withKeyAndMask(keyAndMask);
        return (OTSHashAddress)builder.build();
    }

    /** Copies {@code base}, replacing only the chain address. */
    private static OTSHashAddress withChainAddress(OTSHashAddress base, int chainAddress) {
        return deriveAddress(base, chainAddress, base.getHashAddress(), base.getKeyAndMask());
    }

    /**
     * Advances a hash chain by {@code steps} applications of the keyed, masked hash F.
     *
     * <p>For step {@code k} (0-based) the hash address is {@code startIndex + k}; the key
     * is {@code PRF(publicSeed, address with keyAndMask = 0)}, the bitmask is
     * {@code PRF(publicSeed, address with keyAndMask = 1)}, and the next value is
     * {@code F(key, current XOR bitmask)}.
     *
     * @param startHash      chain value to start from; length must equal the digest size
     * @param startIndex     position of {@code startHash} within the chain
     * @param steps          number of steps to advance; {@code 0} returns {@code startHash} itself
     * @param otsHashAddress address supplying layer, tree, OTS and chain address
     * @return chain value at position {@code startIndex + steps}
     * @throws NullPointerException     if {@code startHash} or {@code otsHashAddress} is {@code null}
     * @throws IllegalArgumentException if {@code startHash} has the wrong length, {@code steps}
     *                                  is negative, or {@code startIndex + steps > w - 1}
     */
    private byte[] chain(byte[] startHash, int startIndex, int steps, OTSHashAddress otsHashAddress) {
        int n = this.params.getTreeDigestSize();
        if (startHash == null) {
            throw new NullPointerException("startHash == null");
        }
        if (startHash.length != n) {
            throw new IllegalArgumentException("startHash needs to be " + n + " bytes");
        }
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        if (otsHashAddress.toByteArray() == null) {
            throw new NullPointerException("otsHashAddress byte array == null");
        }
        if (steps < 0) {
            throw new IllegalArgumentException("steps must not be negative");
        }
        if (startIndex + steps > this.params.getWinternitzParameter() - 1) {
            throw new IllegalArgumentException("max chain length must not be greater than w");
        }
        int chainAddress = otsHashAddress.getChainAddress();
        byte[] current = startHash;
        // Iterative form of the chain: validate once, then walk the steps in ascending order.
        for (int step = 0; step < steps; step++) {
            int hashAddress = startIndex + step;
            byte[] key = this.khf.PRF(this.publicSeed, deriveAddress(otsHashAddress, chainAddress, hashAddress, 0).toByteArray());
            byte[] bitmask = this.khf.PRF(this.publicSeed, deriveAddress(otsHashAddress, chainAddress, hashAddress, 1).toByteArray());
            byte[] masked = new byte[n];
            for (int i = 0; i < n; i++) {
                masked[i] = (byte)(current[i] ^ bitmask[i]);
            }
            current = this.khf.F(key, masked);
        }
        return current;
    }

    /**
     * Splits bytes into base-w digits, most significant digit of each byte first, and
     * stops as soon as {@code outLength} digits have been produced.
     *
     * @param messageDigest bytes to convert
     * @param w             base; only 4 and 16 are accepted
     * @param outLength     number of digits wanted
     * @return mutable list of digits, each in {@code [0, w - 1]}
     * @throws NullPointerException     if {@code messageDigest} is {@code null}
     * @throws IllegalArgumentException if {@code w} is unsupported or {@code outLength}
     *                                  exceeds the digits available in {@code messageDigest}
     */
    private List<Integer> convertToBaseW(byte[] messageDigest, int w, int outLength) {
        if (messageDigest == null) {
            throw new NullPointerException("msg == null");
        }
        if (w != 4 && w != 16) {
            throw new IllegalArgumentException("w needs to be 4 or 16");
        }
        int logW = XMSSUtil.log2(w);
        if (outLength > 8 * messageDigest.length / logW) {
            throw new IllegalArgumentException("outLength too big");
        }
        List<Integer> digits = new ArrayList<Integer>();
        for (byte b : messageDigest) {
            for (int shift = 8 - logW; shift >= 0; shift -= logW) {
                // The mask also discards the sign extension of negative bytes.
                digits.add((b >> shift) & (w - 1));
                if (digits.size() == outLength) {
                    return digits;
                }
            }
        }
        return digits;
    }

    /**
     * Derives a secret value as {@code PRF(secretKeySeed, address)}, where the address
     * carries only the layer, tree and OTS address of {@code otsHashAddress}; the chain
     * address, hash address and key-and-mask are not set on the builder.
     *
     * @param secretKeySeed  PRF key
     * @param otsHashAddress address whose layer, tree and OTS fields are used
     * @return PRF output
     * @throws NullPointerException if {@code otsHashAddress} is {@code null}
     */
    protected byte[] getWOTSPlusSecretKey(byte[] secretKeySeed, OTSHashAddress otsHashAddress) {
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        OTSHashAddress.Builder builder = (OTSHashAddress.Builder)new OTSHashAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress());
        builder = (OTSHashAddress.Builder)builder.withTreeAddress(otsHashAddress.getTreeAddress());
        OTSHashAddress keyAddress = (OTSHashAddress)builder.withOTSAddress(otsHashAddress.getOTSAddress()).build();
        return this.khf.PRF(secretKeySeed, keyAddress.toByteArray());
    }

    /**
     * Derives the secret start value of chain {@code index} as
     * {@code PRF(secretKeySeed, index encoded big-endian in 32 bytes)}.
     *
     * @throws IllegalArgumentException if {@code index} is outside {@code [0, len)}
     */
    private byte[] expandSecretKeySeed(int index) {
        if (index < 0 || index >= this.params.getLen()) {
            throw new IllegalArgumentException("index out of bounds");
        }
        return this.khf.PRF(this.secretKeySeed, XMSSUtil.toBytesBigEndian(index, 32));
    }

    /** Returns the parameter set this instance was created with. */
    protected WOTSPlusParameters getParams() {
        return this.params;
    }

    /** Returns the keyed hash functions used by this instance. */
    protected KeyedHashFunctions getKhf() {
        return this.khf;
    }

    /** Returns a copy of the secret key seed. */
    protected byte[] getSecretKeySeed() {
        return Arrays.clone(this.secretKeySeed);
    }

    /** Returns a copy of the public seed. */
    protected byte[] getPublicSeed() {
        return Arrays.clone(this.publicSeed);
    }

    /**
     * Expands the secret key seed into the full private key, one start value per chain.
     *
     * @return private key with {@code len} elements
     */
    protected WOTSPlusPrivateKeyParameters getPrivateKey() {
        byte[][] privateKey = new byte[this.params.getLen()][];
        for (int i = 0; i < privateKey.length; i++) {
            privateKey[i] = this.expandSecretKeySeed(i);
        }
        return new WOTSPlusPrivateKeyParameters(this.params, privateKey);
    }

    /**
     * Computes the public key by advancing every chain from its secret start value
     * through all {@code w - 1} steps.
     *
     * @param otsHashAddress address supplying the layer, tree, OTS, hash and key-and-mask
     *                       fields; the chain address is overwritten per chain
     * @return public key holding one chain end value per chain index
     * @throws NullPointerException if {@code otsHashAddress} is {@code null}
     */
    WOTSPlusPublicKeyParameters getPublicKey(OTSHashAddress otsHashAddress) {
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        int len = this.params.getLen();
        int fullChain = this.params.getWinternitzParameter() - 1;
        byte[][] publicKey = new byte[len][];
        for (int i = 0; i < len; i++) {
            OTSHashAddress chainAddress = withChainAddress(otsHashAddress, i);
            publicKey[i] = this.chain(this.expandSecretKeySeed(i), 0, fullChain, chainAddress);
        }
        return new WOTSPlusPublicKeyParameters(this.params, publicKey);
    }
}

