package org.bouncycastle.jcajce.provider.kdf.hkdf;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.digests.SHA384Digest;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.generators.HKDFBytesGenerator;
import org.bouncycastle.crypto.params.HKDFParameters;

import javax.crypto.KDFParameters;
import javax.crypto.KDFSpi;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.HKDFParameterSpec;
import java.security.InvalidAlgorithmParameterException;
import java.security.NoSuchAlgorithmException;
import java.security.spec.AlgorithmParameterSpec;
import java.util.Arrays;
import java.util.List;

/**
 * JCA {@link KDFSpi} implementation of HKDF (RFC 5869) backed by Bouncy Castle's
 * {@link HKDFBytesGenerator}.
 * <p>
 * Accepts the JDK {@link HKDFParameterSpec} variants ({@code Extract}, {@code Expand},
 * {@code ExtractThenExpand}) as well as Bouncy Castle's own
 * {@code org.bouncycastle.jcajce.spec.HKDFParameterSpec}.
 * <p>
 * NOTE(review): the single {@code hkdf} generator is re-initialized on every derivation,
 * so concurrent derivations on one instance are not safe — confirm this is acceptable
 * under the KDF concurrency contract.
 */
class HKDFSpi
        extends KDFSpi
{
    protected HKDFBytesGenerator hkdf;

    /**
     * Creates an HKDF SPI using the given digest as the HMAC hash function.
     *
     * @param kdfParameters must be {@code null}; HKDF takes no instance-level parameters.
     * @param digest        hash function used for both the extract and expand steps.
     * @throws InvalidAlgorithmParameterException if {@code kdfParameters} is non-null.
     */
    public HKDFSpi(KDFParameters kdfParameters, Digest digest)
            throws InvalidAlgorithmParameterException
    {
        super(requireNull(kdfParameters, "HKDF" + " does not support parameters"));
        this.hkdf = new HKDFBytesGenerator(digest);
    }

    /**
     * Returns the {@code KDFParameters} used with this {@code KDF} object.
     * <p>
     * The returned parameters may be the same that were used to initialize
     * this {@code KDF} object, or may contain additional default or
     * random parameter values used by the underlying KDF algorithm.
     * If the required parameters were not supplied and can be generated by
     * the {@code KDF} object, the generated parameters are returned;
     * otherwise {@code null} is returned.
     *
     * @return the parameters used with this {@code KDF} object, or
     * {@code null}
     */
    @Override
    protected KDFParameters engineGetParameters()
    {
        return null;
    }

    /**
     * Derives key material as {@link #engineDeriveData} does and wraps it in a
     * {@link SecretKeySpec} for the requested algorithm.
     *
     * @param alg            algorithm name for the returned key.
     * @param derivationSpec an HKDF parameter spec (see {@link #engineDeriveData}).
     * @return the derived key.
     * @throws InvalidAlgorithmParameterException if the spec is missing, unsupported or invalid.
     */
    @Override
    protected SecretKey engineDeriveKey(String alg, AlgorithmParameterSpec derivationSpec)
            throws InvalidAlgorithmParameterException, NoSuchAlgorithmException
    {
        byte[] derivedKey = engineDeriveData(derivationSpec);
        try
        {
            // SecretKeySpec copies the array, so the local buffer can be wiped afterwards.
            return new SecretKeySpec(derivedKey, alg);
        }
        finally
        {
            Arrays.fill(derivedKey, (byte)0);
        }
    }

    /**
     * Runs HKDF according to the supplied spec.
     * <ul>
     * <li>{@code ExtractThenExpand}: extract a PRK from the IKM/salt, then expand it.</li>
     * <li>{@code Extract}: return the PRK only.</li>
     * <li>{@code Expand}: expand an already-extracted PRK.</li>
     * <li>Bouncy Castle {@code HKDFParameterSpec}: extract-then-expand.</li>
     * </ul>
     * Multiple IKM or salt values are concatenated in the order they appear in the spec.
     * An empty salt list means "no salt" (the RFC 5869 default).
     *
     * @param derivationSpec the HKDF parameters.
     * @return the derived bytes (the PRK for {@code Extract}).
     * @throws InvalidAlgorithmParameterException if the spec is null, of an unsupported
     *         type, carries a key without an encoding, or requests an output length that
     *         is negative or exceeds 255 * HashLen.
     */
    @Override
    protected byte[] engineDeriveData(AlgorithmParameterSpec derivationSpec)
            throws InvalidAlgorithmParameterException
    {
        if (derivationSpec == null
            || !(derivationSpec instanceof org.bouncycastle.jcajce.spec.HKDFParameterSpec
                 || derivationSpec instanceof javax.crypto.spec.HKDFParameterSpec))
        {
            throw new InvalidAlgorithmParameterException("Invalid AlgorithmParameterSpec provided");
        }

        if (derivationSpec instanceof HKDFParameterSpec.ExtractThenExpand)
        {
            HKDFParameterSpec.ExtractThenExpand spec = (HKDFParameterSpec.ExtractThenExpand)derivationSpec;

            byte[] ikm = concatenate(spec.ikms(), "IKM");
            byte[] salt = concatenate(spec.salts(), "salt");
            try
            {
                return expand(new HKDFParameters(ikm, nullIfEmpty(salt), spec.info()), spec.length());
            }
            finally
            {
                Arrays.fill(ikm, (byte)0);
                Arrays.fill(salt, (byte)0);
            }
        }
        else if (derivationSpec instanceof HKDFParameterSpec.Extract)
        {
            HKDFParameterSpec.Extract spec = (HKDFParameterSpec.Extract)derivationSpec;

            byte[] ikm = concatenate(spec.ikms(), "IKM");
            byte[] salt = concatenate(spec.salts(), "salt");
            try
            {
                return hkdf.extractPRK(nullIfEmpty(salt), ikm);
            }
            finally
            {
                Arrays.fill(ikm, (byte)0);
                Arrays.fill(salt, (byte)0);
            }
        }
        else if (derivationSpec instanceof HKDFParameterSpec.Expand)
        {
            HKDFParameterSpec.Expand spec = (HKDFParameterSpec.Expand)derivationSpec;

            // The supplied key is already the PRK, so the extract step is skipped.
            byte[] prk = encodedKey(spec.prk(), "PRK");
            return expand(HKDFParameters.skipExtractParameters(prk, spec.info()), spec.length());
        }
        else if (derivationSpec instanceof org.bouncycastle.jcajce.spec.HKDFParameterSpec)
        {
            org.bouncycastle.jcajce.spec.HKDFParameterSpec spec = (org.bouncycastle.jcajce.spec.HKDFParameterSpec)derivationSpec;

            return expand(new HKDFParameters(spec.getIKM(), spec.getSalt(), spec.getInfo()),
                spec.getOutputLength());
        }
        else
        {
            throw new InvalidAlgorithmParameterException("invalid HKDFParameterSpec provided");
        }
    }

    /**
     * Initializes the shared generator with {@code params} and produces {@code length} bytes.
     *
     * @throws InvalidAlgorithmParameterException if {@code length} is negative or exceeds
     *         the RFC 5869 limit of 255 * HashLen bytes.
     */
    private byte[] expand(HKDFParameters params, int length)
            throws InvalidAlgorithmParameterException
    {
        // RFC 5869 section 2.3: L <= 255 * HashLen.
        int maxLength = 255 * hkdf.getDigest().getDigestSize();
        if (length < 0 || length > maxLength)
        {
            throw new InvalidAlgorithmParameterException("requested output length " + length
                + " is out of range; HKDF can produce at most " + maxLength + " bytes");
        }

        hkdf.init(params);

        byte[] derivedData = new byte[length];
        hkdf.generateBytes(derivedData, 0, length);
        return derivedData;
    }

    /** Concatenates the encodings of {@code keys} in list order; empty list gives an empty array. */
    private static byte[] concatenate(List<SecretKey> keys, String label)
            throws InvalidAlgorithmParameterException
    {
        byte[][] parts = new byte[keys.size()][];
        int total = 0;
        for (int i = 0; i < parts.length; i++)
        {
            parts[i] = encodedKey(keys.get(i), label);
            total += parts[i].length;
        }

        byte[] result = new byte[total];
        int offset = 0;
        for (byte[] part : parts)
        {
            System.arraycopy(part, 0, result, offset, part.length);
            offset += part.length;
        }
        return result;
    }

    /** Returns {@code key.getEncoded()}, rejecting keys that do not expose their encoding. */
    private static byte[] encodedKey(SecretKey key, String label)
            throws InvalidAlgorithmParameterException
    {
        byte[] encoded = key.getEncoded();
        if (encoded == null)
        {
            throw new InvalidAlgorithmParameterException(label + " key does not provide an encoding");
        }
        return encoded;
    }

    /** Maps an empty salt to {@code null} so "no salt" is passed to the generator as absent. */
    private static byte[] nullIfEmpty(byte[] value)
    {
        return value.length == 0 ? null : value;
    }

    /**
     * Returns {@code null} after checking that no parameters were supplied.
     *
     * @throws InvalidAlgorithmParameterException with {@code message} if {@code kdfParameters} is non-null.
     */
    private static KDFParameters requireNull(KDFParameters kdfParameters,
                                             String message) throws InvalidAlgorithmParameterException
    {
        if (kdfParameters != null)
        {
            throw new InvalidAlgorithmParameterException(message);
        }
        return null;
    }

    /** HKDF using SHA-256. */
    public static class HKDFwithSHA256 extends HKDFSpi
    {
        public HKDFwithSHA256(KDFParameters kdfParameters) throws InvalidAlgorithmParameterException
        {
            super(kdfParameters, new SHA256Digest());
        }
        public HKDFwithSHA256() throws InvalidAlgorithmParameterException
        {
            this(null);
        }
    }

    /** HKDF using SHA-384. */
    public static class HKDFwithSHA384 extends HKDFSpi
    {
        public HKDFwithSHA384(KDFParameters kdfParameters) throws InvalidAlgorithmParameterException
        {
            super(kdfParameters, new SHA384Digest());
        }
        public HKDFwithSHA384() throws InvalidAlgorithmParameterException
        {
            this(null);
        }
    }

    /** HKDF using SHA-512. */
    public static class HKDFwithSHA512 extends HKDFSpi
    {
        public HKDFwithSHA512(KDFParameters kdfParameters) throws InvalidAlgorithmParameterException
        {
            super(kdfParameters, new SHA512Digest());
        }
        public HKDFwithSHA512() throws InvalidAlgorithmParameterException
        {
            this(null);
        }
    }

}
