/*
 * Decompiled with CFR 0.152.
 */
package cn.com.infosec.jcajce.provider.asymmetric.sm2;

import cn.com.infosec.asn1.ASN1Encodable;
import cn.com.infosec.asn1.ASN1Integer;
import cn.com.infosec.asn1.ASN1OctetString;
import cn.com.infosec.asn1.ASN1Sequence;
import cn.com.infosec.asn1.DEROctetString;
import cn.com.infosec.asn1.DERSequence;
import cn.com.infosec.asn1.x9.X962NamedCurves;
import cn.com.infosec.asn1.x9.X9ECParameters;
import cn.com.infosec.crypto.digests.SM3Digest;
import cn.com.infosec.jcajce.provider.asymmetric.ec.INFOSECECPrivateKey;
import cn.com.infosec.jcajce.provider.asymmetric.ec.INFOSECECPublicKey;
import cn.com.infosec.jcajce.provider.asymmetric.sm2.SM2PrivateKey;
import cn.com.infosec.jcajce.provider.asymmetric.sm2.SM2PublicKey;
import cn.com.infosec.jcajce.provider.asymmetric.util.BaseCipherSpi;
import cn.com.infosec.jcajce.provider.asymmetric.util.SM2Util;
import cn.com.infosec.math.ec.ECCurve;
import cn.com.infosec.math.ec.ECPoint;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.AlgorithmParameters;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.util.Arrays;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.ShortBufferException;

public class CipherSpi
extends BaseCipherSpi {
    Key key;
    boolean encrypt;
    static X9ECParameters param = X962NamedCurves.getByName("prime256v1_sm2");
    static ECCurve curve = param.getCurve();
    static ECPoint G = param.getG();
    static BigInteger n = param.getN();

    private static byte[] sm3hash(byte[] d) {
        SM3Digest sm3 = new SM3Digest();
        sm3.update(d, 0, d.length);
        byte[] r = new byte[32];
        sm3.doFinal(r, 0);
        return r;
    }

    @Override
    protected int engineGetBlockSize() {
        return 0;
    }

    @Override
    protected int engineGetKeySize(Key key) {
        return 32;
    }

    @Override
    protected int engineGetOutputSize(int inputLen) {
        return 65 + inputLen + 32;
    }

    @Override
    protected void engineInit(int opmode, Key key, AlgorithmParameterSpec params, SecureRandom random) throws InvalidKeyException, InvalidAlgorithmParameterException {
        this.key = key;
        if (opmode == 1 || opmode == 3) {
            if (key instanceof INFOSECECPublicKey) {
                try {
                    this.key = SM2Util.pack2PublicKey(((INFOSECECPublicKey)key).getQ());
                }
                catch (Exception e) {
                    throw new InvalidKeyException("must be passed public EC key for encryption");
                }
            } else if (key instanceof SM2PublicKey) {
                this.key = key;
            } else {
                throw new InvalidKeyException("must be passed public EC key for encryption");
            }
            this.encrypt = true;
        } else if (opmode == 2 || opmode == 4) {
            if (key instanceof INFOSECECPrivateKey) {
                try {
                    this.key = SM2Util.pack2PrivateKey(((INFOSECECPrivateKey)key).getD());
                }
                catch (Exception e) {
                    throw new InvalidKeyException("must be passed private EC key for encryption");
                }
            } else if (key instanceof SM2PrivateKey) {
                this.key = key;
            } else {
                throw new InvalidKeyException("must be passed private EC key for encryption");
            }
            this.encrypt = false;
        } else {
            throw new InvalidKeyException("must be passed EC key");
        }
    }

    @Override
    protected void engineInit(int opmode, Key key, AlgorithmParameters params, SecureRandom random) throws InvalidKeyException, InvalidAlgorithmParameterException {
        AlgorithmParameterSpec paramSpec = null;
        this.engineInit(opmode, key, paramSpec, random);
    }

    @Override
    protected void engineInit(int opmode, Key key, SecureRandom random) throws InvalidKeyException {
        try {
            this.engineInit(opmode, key, (AlgorithmParameterSpec)null, random);
        }
        catch (InvalidAlgorithmParameterException e) {
            throw new InvalidKeyException("Eeeek! " + e.toString(), e);
        }
    }

    public static byte[] combineByteArray(byte[] a, int rlen1, byte[] b, int rlen2) {
        byte[] ca = new byte[rlen1];
        for (int i = 0; i < ca.length; ++i) {
            ca[i] = 0;
        }
        if (a.length == rlen1) {
            System.arraycopy(a, 0, ca, 0, rlen1);
        } else if (a.length > rlen1) {
            int aStart = a.length - rlen1;
            System.arraycopy(a, aStart, ca, 0, rlen1);
        } else {
            int caStart = rlen1 - a.length;
            System.arraycopy(a, 0, ca, caStart, a.length);
        }
        byte[] cb = new byte[rlen2];
        for (int i = 0; i < cb.length; ++i) {
            cb[i] = 0;
        }
        if (b.length == rlen2) {
            System.arraycopy(b, 0, cb, 0, rlen2);
        } else if (b.length > rlen2) {
            int bStart = b.length - rlen2;
            System.arraycopy(b, bStart, cb, 0, rlen2);
        } else {
            int cbStart = rlen2 - b.length;
            System.arraycopy(b, 0, cb, cbStart, b.length);
        }
        byte[] c = new byte[rlen1 + rlen2];
        System.arraycopy(ca, 0, c, 0, rlen1);
        System.arraycopy(cb, 0, c, rlen1, rlen2);
        return c;
    }

    public static byte[] combineByteArray(byte[] a, byte[] b) {
        byte[] c = new byte[a.length + b.length];
        System.arraycopy(a, 0, c, 0, a.length);
        System.arraycopy(b, 0, c, a.length, b.length);
        return c;
    }

    public static byte[] int2bytes(int num) {
        byte[] bytes = new byte[4];
        for (int i = 0; i < 4; ++i) {
            bytes[3 - i] = (byte)(0xFF & num >> i * 8);
        }
        return bytes;
    }

    private static byte[] KDF(byte[] Z, int klen) {
        int blockSize = 32;
        int outLen = klen % blockSize == 0 ? klen : (klen / blockSize + 1) * blockSize;
        byte[] out = new byte[outLen];
        int ct = 1;
        byte[] hash = null;
        int outOffset = 0;
        for (int i = 0; i < outLen / blockSize; ++i) {
            byte[] bsct = CipherSpi.int2bytes(ct);
            byte[] buf = new byte[Z.length + bsct.length];
            System.arraycopy(Z, 0, buf, 0, Z.length);
            System.arraycopy(bsct, 0, buf, Z.length, bsct.length);
            hash = CipherSpi.sm3hash(buf);
            System.arraycopy(hash, 0, out, outOffset, hash.length);
            outOffset += hash.length;
            ++ct;
        }
        return out;
    }

    private static byte[] enc(SM2PublicKey pub, byte[] plain) {
        BigInteger k;
        byte[] x = null;
        byte[] y = null;
        byte[] t = null;
        ECPoint C1 = null;
        byte[] C3 = null;
        do {
            k = new BigInteger(256, new SecureRandom());
            C1 = G.multiply(k).normalize();
            BigInteger h = pub.getParameters().getH();
            ECPoint S = pub.getQ().multiply(h).normalize();
            if (S.equals(BigInteger.ZERO)) {
                throw new RuntimeException("calc h multiply Q is zero");
            }
            ECPoint Q2 = pub.getQ().multiply(k).normalize();
            x = Q2.getAffineXCoord().toBigInteger().toByteArray();
            y = Q2.getAffineYCoord().toBigInteger().toByteArray();
            byte[] xy = null;
            try {
                xy = CipherSpi.combineByteArray(x, 32, y, 32);
            }
            catch (Throwable ex) {
                System.out.println("combilexy32 x.len=" + x.length + " y.len=" + y.length);
                throw new RuntimeException(ex);
            }
            t = CipherSpi.KDF(xy, plain.length);
        } while (k.equals(BigInteger.ZERO) || new BigInteger(1, t).equals(BigInteger.ZERO));
        byte[] C2 = new byte[plain.length];
        for (int i = 0; i < C2.length; ++i) {
            C2[i] = (byte)(plain[i] ^ t[i]);
        }
        byte[] tbh1 = null;
        try {
            tbh1 = CipherSpi.combineByteArray(x, 32, plain, plain.length);
        }
        catch (Throwable ex) {
            System.out.println("combiletbh1 x.len=" + x.length + " plain.len=" + plain.length);
            throw new RuntimeException(ex);
        }
        byte[] tbh = null;
        try {
            tbh = CipherSpi.combineByteArray(tbh1, tbh1.length, y, 32);
        }
        catch (Throwable ex) {
            System.out.println("combiletbh2 tbh1.len=" + tbh1.length + " y.len=" + y.length);
            throw new RuntimeException(ex);
        }
        C3 = CipherSpi.sm3hash(tbh);
        ASN1Integer ax = new ASN1Integer(C1.getAffineXCoord().toBigInteger());
        ASN1Integer ay = new ASN1Integer(C1.getAffineYCoord().toBigInteger());
        DEROctetString ah = new DEROctetString(C3);
        DEROctetString ac = new DEROctetString(C2);
        DERSequence aseq = new DERSequence(new ASN1Encodable[]{ax, ay, ah, ac});
        try {
            return aseq.getEncoded();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private static byte[] dec(SM2PrivateKey pri, byte[] encData) {
        BigInteger biC3;
        byte[] y;
        ASN1Sequence aseq = null;
        try {
            aseq = (ASN1Sequence)ASN1Sequence.fromByteArray(encData);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        ASN1Integer ax = (ASN1Integer)aseq.getObjectAt(0);
        ASN1Integer ay = (ASN1Integer)aseq.getObjectAt(1);
        ASN1OctetString ah = (ASN1OctetString)aseq.getObjectAt(2);
        ASN1OctetString ac = (ASN1OctetString)aseq.getObjectAt(3);
        ECPoint Q1 = curve.createPoint(ax.getPositiveValue(), ay.getPositiveValue());
        byte[] C2 = ac.getOctets();
        byte[] C3 = ah.getOctets();
        BigInteger h = pri.getParameters().getH();
        ECPoint S = Q1.multiply(h).normalize();
        if (S.equals(BigInteger.ZERO)) {
            throw new RuntimeException(" C1 multiply h is zero");
        }
        BigInteger d = pri.getD();
        ECPoint dC1 = Q1.multiply(d).normalize();
        byte[] x = dC1.getAffineXCoord().toBigInteger().toByteArray();
        byte[] xy = CipherSpi.combineByteArray(x, 32, y = dC1.getAffineYCoord().toBigInteger().toByteArray(), 32);
        byte[] t = CipherSpi.KDF(xy, C2.length);
        if (new BigInteger(t).equals(BigInteger.ZERO)) {
            throw new RuntimeException("KDF return  zero");
        }
        byte[] M = new byte[C2.length];
        for (int i = 0; i < M.length; ++i) {
            M[i] = (byte)((C2[i] ^ t[i]) & 0xFF);
        }
        byte[] xM = CipherSpi.combineByteArray(x, 32, M, M.length);
        byte[] xMy = CipherSpi.combineByteArray(xM, xM.length, y, 32);
        byte[] hash = CipherSpi.sm3hash(xMy);
        BigInteger biM = new BigInteger(1, hash);
        if (!biM.equals(biC3 = new BigInteger(1, C3))) {
            throw new RuntimeException("decrypt reult not match C3");
        }
        return M;
    }

    @Override
    protected byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen) throws IllegalBlockSizeException, BadPaddingException {
        if (this.encrypt) {
            byte[] M = new byte[inputLen];
            System.arraycopy(input, inputOffset, M, 0, inputLen);
            return CipherSpi.enc((SM2PublicKey)this.key, M);
        }
        byte[] M = new byte[inputLen];
        System.arraycopy(input, inputOffset, M, 0, inputLen);
        return CipherSpi.dec((SM2PrivateKey)this.key, M);
    }

    @Override
    protected int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset) throws IllegalBlockSizeException, BadPaddingException {
        byte[] out = this.engineDoFinal(input, inputOffset, inputLen);
        System.arraycopy(out, 0, output, outputOffset, out.length);
        return 0;
    }

    @Override
    protected byte[] engineUpdate(byte[] arg0, int arg1, int arg2) {
        return new byte[0];
    }

    @Override
    protected int engineUpdate(byte[] arg0, int arg1, int arg2, byte[] arg3, int arg4) throws ShortBufferException {
        return 0;
    }
}

