/**
 * Copyright (C) 2011-2022 dCache.org <support@dcache.org>
 * 
 * This file is part of xrootd4j.
 * 
 * xrootd4j is free software: you can redistribute it and/or modify it under the terms of the GNU
 * Lesser General Public License as published by the Free Software Foundation, either version 3 of
 * the License, or (at your option) any later version.
 * 
 * xrootd4j is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
 * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License along with xrootd4j.  If
 * not, see http://www.gnu.org/licenses/.
 */
package org.dcache.xrootd.plugins.authn.gsi;

import static org.dcache.xrootd.plugins.authn.gsi.GSIRequestHandler.RANDOM;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PublicKey;
import java.util.Arrays;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.KeyAgreement;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.interfaces.DHPublicKey;
import javax.crypto.spec.DHParameterSpec;
import javax.crypto.spec.DHPublicKeySpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.bouncycastle.asn1.ASN1InputStream;
import org.bouncycastle.asn1.pkcs.DHParameter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * This class represents a Diffie-Hellman (DH) session. After the DH key agreement
 * has been completed, the resulting shared secret is used to derive a session key
 * for (symmetric) encryption/decryption.
 *
 * <p>A server-side session generates its key pair at construction time from
 * {@link #DH_PARAMETERS}; a client-side session defers key-pair generation until
 * the remote DH parameters arrive in {@link #finaliseKeyAgreement(String)}.
 *
 * <p>Instances hold mutable per-session state (key agreement, current IV) and
 * perform no synchronization.
 *
 * @author radicke
 * @author tzangerl
 *
 */
public class DHSession {

    private static final Logger LOGGER = LoggerFactory.getLogger(DHSession.class);

    private static final String DH_ALGORITHM_NAME = "DH";
    private static final String DH_HEADER = "-----BEGIN DH PARAMETERS-----";
    private static final String DH_FOOTER = "-----END DH PARAMETERS-----";
    private static final String DH_PUBKEY_HEADER = "---BPUB---";
    private static final String DH_PUBKEY_FOOTER = "---EPUB---";

    // The 512-bit prime being part of the DH parameter set.
    // This specific number set was created by using OpenSSL and passes
    // its validity tests.
    // NOTE(review): 512-bit DH groups are no longer considered secure by
    // current standards; presumably this value is kept for interoperability
    // with existing xrootd peers -- confirm before changing it.
    private static final String DH_PRIME =
          ("00:a8:37:9d:6f:ff:e8:63:a0:b1:47:0c:26:dd:1a:"
                + "45:0b:e2:03:9a:f0:83:b1:ba:5b:fa:1d:2f:5b:2a:"
                + "89:08:02:d8:c4:d4:66:8d:14:8d:35:bb:24:b1:af:"
                + "1a:d3:75:c7:c0:3b:61:aa:85:3f:56:69:ae:f2:67:"
                + "da:20:87:5d:93").replaceAll("[:\\s]+", "");

    // the 512 bit DH parameter set used for all DH sessions, consisting
    // of the prime above and the generator value of 2
    // These default values are only used when dCache acts as the server
    static final DHParameterSpec DH_PARAMETERS = new DHParameterSpec(
          new BigInteger(DH_PRIME, 16), BigInteger.valueOf(2));
    private DHParameterSpec _dhParameterSpec;
    private KeyPair _localDHKeyPair;
    private KeyAgreement _keyAgreement;
    private int _sessionIVLen;
    private byte[] IV;
    private boolean paddedKey;

    /**
     * Construct new Diffie-Hellman key exchange session.
     *
     * @param isServer if {@code true}, the local key pair is generated
     *        immediately from {@link #DH_PARAMETERS}; otherwise generation is
     *        deferred until {@link #finaliseKeyAgreement(String)} receives the
     *        remote DH parameters
     * @param sessionIVLen length of the initialization vector carried as a
     *        prefix of each encrypted message; 0 means no IV prefix is used and
     *        an all-zero IV is assumed
     * @throws InvalidAlgorithmParameterException Invalid DH parameters (primes)
     * @throws NoSuchAlgorithmException DH algorithm not available in VM
     * @throws InvalidKeyException Private key generated by DH generator invalid
     * @throws NoSuchProviderException Bouncy castle provider does not exist
     */
    public DHSession(boolean isServer, int sessionIVLen)
          throws InvalidAlgorithmParameterException, NoSuchAlgorithmException,
          InvalidKeyException, NoSuchProviderException {
        if (isServer) {
            _dhParameterSpec = DH_PARAMETERS;
            initialize();
        }

        _sessionIVLen = sessionIVLen;

        /*
         *  This should only be true for 4.9+ (10400) or greater, and
         *  (for the client) only if the server says it supports it.
         */
        paddedKey = false;
    }

    /**
     * Generate the local DH key pair from {@code _dhParameterSpec} and
     * initialise the key agreement with its private key, both using the
     * BouncyCastle ("BC") provider.
     */
    private void initialize()
          throws InvalidAlgorithmParameterException, NoSuchAlgorithmException,
          InvalidKeyException, NoSuchProviderException {
        KeyPairGenerator kpairGen =
              KeyPairGenerator.getInstance(DH_ALGORITHM_NAME, "BC");
        kpairGen.initialize(_dhParameterSpec);
        _localDHKeyPair = kpairGen.generateKeyPair();
        _keyAgreement = KeyAgreement.getInstance(DH_ALGORITHM_NAME, "BC");
        _keyAgreement.init(_localDHKeyPair.getPrivate());
    }

    /**
     * Encode the local DH material for the remote endpoint: the PEM-encoded DH
     * parameters, a newline, then the local public value Y in hex enclosed in
     * {@code ---BPUB---} / {@code ---EPUB---}.
     *
     * @return the encoded DH material
     * @throws IOException if encoding the DH parameters fails
     * @throws IllegalStateException if no local key pair has been generated yet
     *         (client side before {@link #finaliseKeyAgreement(String)})
     */
    public String getEncodedDHMaterial() throws IOException {
        if (_localDHKeyPair == null) {
            throw new IllegalStateException(
                  "local DH key pair has not been generated yet");
        }

        String dhparams =
              CertUtil.toPEM(toDER(_dhParameterSpec), DH_HEADER, DH_FOOTER);
        DHPublicKey pubkey = (DHPublicKey) _localDHKeyPair.getPublic();

        return dhparams + '\n' + DH_PUBKEY_HEADER +
              pubkey.getY().toString(16) + DH_PUBKEY_FOOTER;
    }

    /**
     * Complete the DH key agreement with the message sent by the remote
     * endpoint.
     *
     * <p>The message consists of PEM-encoded DH parameters followed by the
     * remote public value Y in hex, enclosed in {@code ---BPUB---} and
     * {@code ---EPUB---}. If no local key pair exists yet (client side), one is
     * generated from the received parameters; otherwise the received P and G
     * must equal the local ones.
     *
     * @param dhmessage the DH message received from the remote endpoint
     * @throws IllegalArgumentException if the message lacks the public-key
     *         envelope or the public value is not valid hex
     * @throws IOException if the DH parameters cannot be decoded
     * @throws GeneralSecurityException if the remote DH parameters differ from
     *         the local ones, or the key agreement fails
     */
    public void finaliseKeyAgreement(String dhmessage) throws IOException,
          GeneralSecurityException, IllegalStateException {
        int delimitingIndex = dhmessage.indexOf(DH_PUBKEY_HEADER);

        if (delimitingIndex < 0) {
            throw new IllegalArgumentException("Illegal DH message: "
                  + dhmessage);
        }

        String dhparams = dhmessage.substring(0, delimitingIndex);

        // Drop line breaks so the envelope holds one contiguous hex string.
        String remotePubKeyString =
              removeCharFromString(dhmessage.substring(delimitingIndex), '\n');

        DHParameterSpec params = fromDER(CertUtil.fromPEM(dhparams,
              DH_HEADER,
              DH_FOOTER));

        LOGGER.debug("Remote endpoint sent: P = {}, G = {}, L = {},",
              params.getP(), params.getG(), params.getL());

        if (_keyAgreement == null) {
            int l = params.getL();
            /*
             * Note:  setting the value of L to the non-zero bit length of P
             * seems to be required for the key pair generation to work.
             */
            _dhParameterSpec = new DHParameterSpec(params.getP(), params.getG(),
                  l == 0 ? params.getP().bitLength() : l);
            initialize();
        } else if (!(_dhParameterSpec.getP().equals(params.getP())
              && _dhParameterSpec.getG().equals(params.getG()))) {
            throw new GeneralSecurityException(
                  "remote DH parameters differ from local ones");
        }

        // The string starts with the header by construction; it must also
        // end with the footer, with room for both markers.
        int headerLength = DH_PUBKEY_HEADER.length();
        int footerLength = DH_PUBKEY_FOOTER.length();
        if (remotePubKeyString.length() < headerLength + footerLength
              || !remotePubKeyString.endsWith(DH_PUBKEY_FOOTER)) {
            throw new IllegalArgumentException("Illegal DH public key: "
                  + remotePubKeyString);
        }

        // parse the hex-encoded public value Y between the markers
        BigInteger remoteY = new BigInteger(remotePubKeyString.substring(
              headerLength, remotePubKeyString.length() - footerLength), 16);

        // convert into a public key
        KeyFactory keyfac = KeyFactory.getInstance(DH_ALGORITHM_NAME, "BC");
        PublicKey remotePubKey = keyfac.generatePublic(new DHPublicKeySpec(
              remoteY, params.getP(), params.getG()));

        // finalise DH key agreement
        _keyAgreement.doPhase(remotePubKey, true);
    }

    /**
     * Decrypt a message received from the remote endpoint.
     *
     * <p>If a session IV length is set, the first {@code sessionIVLen} bytes of
     * {@code encrypted} are taken as the IV and only the remainder is
     * decrypted; otherwise the whole buffer is decrypted with an all-zero IV of
     * {@code blocksize} bytes.
     *
     * @param cipherSpec transformation passed to {@link Cipher#getInstance}
     * @param keySpec secret-key algorithm name of the session key
     * @param blocksize number of shared-secret bytes used as the session key
     * @param encrypted the received message
     * @return the decrypted bytes
     * @throws IllegalBlockSizeException if the message is shorter than the IV
     *         prefix, or its length is invalid for the cipher
     * @throws IllegalStateException if the key agreement is not initialised
     */
    public byte[] decrypt(String cipherSpec,
          String keySpec,
          int blocksize,
          byte[] encrypted)
          throws InvalidKeyException,
          IllegalStateException, NoSuchAlgorithmException,
          NoSuchPaddingException, IllegalBlockSizeException,
          BadPaddingException, InvalidAlgorithmParameterException,
          NoSuchProviderException {
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("encrypted:\n{}", dumpToString(encrypted));
        }

        byte[] ciphertext = getIVFromMessagePrefix(encrypted, blocksize);

        byte[] decrypted = translate(cipherSpec,
              keySpec,
              blocksize,
              ciphertext,
              Cipher.DECRYPT_MODE);

        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("decrypted:\n{}", dumpToString(decrypted));
        }

        return decrypted;
    }

    /**
     * Select how the DH shared secret is generated: {@code true} uses the
     * provider's default (padded) secret, {@code false} uses the unpadded
     * "TlsPremasterSecret" form.
     */
    public void setPaddedKey(boolean paddedKey) {
        this.paddedKey = paddedKey;
    }

    /**
     * Set the length of the IV prefix carried by each encrypted message
     * (0 for no prefix and an all-zero IV).
     */
    public void setSessionIVLen(int len) {
        LOGGER.debug("Setting sessionIVLen to {}.", len);
        _sessionIVLen = len;
    }

    /**
     * Encrypt a message for the remote endpoint.
     *
     * <p>If a session IV length is set, a fresh random IV is generated and sent
     * in clear ahead of the ciphertext ({@code IV || ciphertext}), which is the
     * layout {@link #decrypt} expects; otherwise an all-zero IV is used and
     * nothing is prefixed.
     *
     * @param cipherSpec transformation passed to {@link Cipher#getInstance}
     * @param keySpec secret-key algorithm name of the session key
     * @param blocksize number of shared-secret bytes used as the session key,
     *        and the length of the generated IV
     * @param unencrypted the plaintext
     * @return the (optionally IV-prefixed) ciphertext
     * @throws IllegalStateException if the key agreement is not initialised
     */
    public byte[] encrypt(String cipherSpec,
          String keySpec,
          int blocksize,
          byte[] unencrypted)
          throws InvalidKeyException,
          IllegalStateException, NoSuchAlgorithmException,
          NoSuchPaddingException, IllegalBlockSizeException,
          BadPaddingException, InvalidAlgorithmParameterException,
          NoSuchProviderException {
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("unencrypted:\n{}", dumpToString(unencrypted));
        }

        refreshIV(blocksize);

        /*
         * The IV is prefixed to the ciphertext, not to the plaintext: with a
         * chaining mode such as CBC, encrypting IV||data under that same IV
         * makes the first ciphertext block E_k(0) for every message.  The
         * receiver takes that block as the IV, so the random IV would never
         * take effect.
         */
        byte[] encrypted = prefixedBuffer(translate(cipherSpec,
              keySpec,
              blocksize,
              unencrypted,
              Cipher.ENCRYPT_MODE));

        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("encrypted:\n{}", dumpToString(encrypted));
        }

        return encrypted;
    }

    /**
     * Set {@code IV} for decryption and return the part of {@code encrypted}
     * that is to be decrypted.
     *
     * @param encrypted the received message
     * @param blocksize length of the all-zero IV used when no IV prefix is set
     * @return {@code encrypted} itself if no IV prefix is set, otherwise a copy
     *         without the first {@code sessionIVLen} bytes
     * @throws IllegalBlockSizeException if {@code encrypted} is shorter than
     *         the IV prefix
     */
    private byte[] getIVFromMessagePrefix(byte[] encrypted, int blocksize)
          throws IllegalBlockSizeException {
        if (_sessionIVLen == 0) {
            IV = new byte[blocksize];
            traceIV();
            return encrypted;
        }

        if (encrypted.length < _sessionIVLen) {
            throw new IllegalBlockSizeException("Encrypted message of "
                  + encrypted.length + " bytes is shorter than the "
                  + _sessionIVLen + "-byte IV prefix");
        }

        IV = Arrays.copyOf(encrypted, _sessionIVLen);
        traceIV();

        return Arrays.copyOfRange(encrypted, _sessionIVLen, encrypted.length);
    }

    /**
     * Prepend the current IV to {@code in} if an IV prefix is in use;
     * otherwise return {@code in} unchanged.
     */
    private byte[] prefixedBuffer(byte[] in) {
        if (_sessionIVLen == 0) {
            return in;
        }

        byte[] out = new byte[IV.length + in.length];

        System.arraycopy(IV, 0, out, 0, IV.length);
        System.arraycopy(in, 0, out, IV.length, in.length);

        return out;
    }

    /**
     * Replace {@code IV} with a new {@code blocksize}-byte array whose first
     * {@code sessionIVLen} bytes are random characters from
     * {@code [./0-9A-Za-z]}; any remaining bytes are zero.
     */
    private void refreshIV(int blocksize) {
        IV = new byte[blocksize];
        if (_sessionIVLen > 0) {
            for (int i = 0; i < _sessionIVLen; ++i) {
                while (true) {
                    byte next = (byte) RANDOM.nextInt(Byte.MAX_VALUE);
                    if ((next >= '.' && next <= '9') ||
                          (next >= 'A' && next <= 'Z') ||
                          (next >= 'a' && next <= 'z')) {
                        IV[i] = next;
                        break;
                    }
                }
            }
        }
    }

    /**
     * Run {@code buffer} through {@code cipherSpec} in the given mode, keyed
     * with the first {@code blocksize} bytes of the DH shared secret and using
     * the current {@code IV}.
     *
     * @throws IllegalStateException if the key agreement is not initialised
     */
    private byte[] translate(String cipherSpec,
          String keySpec,
          int blocksize,
          byte[] buffer,
          int mode)
          throws InvalidKeyException,
          IllegalStateException, NoSuchAlgorithmException,
          NoSuchPaddingException, IllegalBlockSizeException,
          BadPaddingException, InvalidAlgorithmParameterException,
          NoSuchProviderException {
        if (_keyAgreement == null) {
            throw new IllegalStateException(
                  "DH key agreement has not been initialised");
        }

        IvParameterSpec paramSpec = new IvParameterSpec(IV);
        Cipher cipher = Cipher.getInstance(cipherSpec, "BC");

        byte[] encoded;

        if (paddedKey) {
            LOGGER.debug("Using padded DH secret generation.");
            encoded = _keyAgreement.generateSecret();
        } else {
            /*
             * "TlsPremasterSecret" algorithm forces pre 1.50
             * bouncycastle behavior of generation of secret
             * for compatibility with xrootd client
             */
            LOGGER.debug("Using unpadded (TlsPremasterSecret) DH secret generation.");
            encoded = _keyAgreement
                  .generateSecret("TlsPremasterSecret")
                  .getEncoded();
        }

        /*
         * Use of the TlsPremasterSecret encoding on encryption can sometimes produce
         * an array where the final 0's have been truncated.  This unfortunately
         * does not play well with the key finalization.  Here we simply add
         * back the missing padding.
         *
         * Note:  the non-Tls encoding pads by prepending, not appending; this
         * seems to be unacceptable to servers using the ssl DH_compute_key
         * (unpadded) method.
         */
        if (encoded.length < blocksize && mode == Cipher.ENCRYPT_MODE) {
            byte[] defective = encoded;
            encoded = Arrays.copyOf(defective, blocksize);
            LOGGER.info("Adjusting truncated encoded array by appending 0s.");
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace("OLD:\n{}\nNEW:\n{}",
                      dumpToString(defective), dumpToString(encoded));
            }
        }

        /* the session key is the first blocksize bytes of the shared secret
           (a 128-bit key for a blocksize of 16) */
        SecretKey sessionKey = new SecretKeySpec(encoded,
              0,
              blocksize,
              keySpec);
        cipher.init(mode, sessionKey, paramSpec);
        return cipher.doFinal(buffer);
    }

    /** Trace-log the current {@code IV}, if trace logging is enabled. */
    private void traceIV() {
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("initialization vector:\n{}", dumpToString(IV));
        }
    }

    /** Render {@code bytes} via {@code GSIBucketUtils.dumpBytes} for logging. */
    private static String dumpToString(byte[] bytes) {
        StringBuilder builder = new StringBuilder();
        GSIBucketUtils.dumpBytes(builder, bytes);
        return builder.toString();
    }

    /**
     * Remove all occurrences of a character (matched literally, not as a
     * regular expression).
     *
     * @param s
     *            the string
     * @param c
     *            the char to be removed
     * @return the resulting string
     */
    private static String removeCharFromString(String s, char c) {
        return s.replace(String.valueOf(c), "");
    }

    /**
     * Creates a DHParameterSpec object from the DER-encoded byte sequence.
     * Only P and G are taken over; the returned spec has L = 0.
     *
     * @param der the DER-encoded byte sequence
     * @return the DHParameterSpec object
     * @throws IOException if the deserialisation goes wrong or {@code der}
     *         contains no ASN.1 object
     */
    private static DHParameterSpec fromDER(byte[] der) throws IOException {
        try (ASN1InputStream derInputStream =
                    new ASN1InputStream(new ByteArrayInputStream(der))) {
            DHParameter dhparam =
                  DHParameter.getInstance(derInputStream.readObject());
            if (dhparam == null) {
                throw new IOException("No DH parameters found in DER input");
            }
            return new DHParameterSpec(dhparam.getP(), dhparam.getG());
        }
    }

    /**
     * Creates a DER-encoded byte sequence from the DH parameter spec, using
     * the bit length of the prime P as the L value.
     *
     * @param paramspec the DH parameter object
     * @return the DER-encoded byte sequence of the DH Parameter object
     * @throws IOException if the encoding fails
     */
    private static byte[] toDER(DHParameterSpec paramspec) throws IOException {
        DHParameter derParams = new DHParameter(
              paramspec.getP(),               // prime
              paramspec.getG(),               // generator
              paramspec.getP().bitLength());  // bit length of the prime

        return derParams.getEncoded("DER");
    }
}
