package com.bxm.warcar.web.util.aes;

import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;
import javax.servlet.http.HttpServletRequest;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.bxm.warcar.web.util.EncryptContext;
import com.bxm.warcar.web.util.Encryptor;

/**
 * AES based {@link Encryptor} (protocol version 2).
 *
 * <p>The format produced by {@link #encrypt(String, String)} and consumed by
 * {@link #decrypt(String, String)} is
 * {@code Base64(lowercaseHex(AES/ECB/PKCS5Padding(UTF-8(content))))}. The 128-bit AES key is
 * the UTF-8 bytes of the first 16 characters of the lowercase MD5 hex digest of the key string.
 *
 * <p>The class holds no mutable state; a new {@link Cipher} is created for every call.
 *
 * <p>NOTE(review): ECB mode reveals repeated plaintext blocks, and the key derivation keeps
 * only 16 hex characters (64 bits of the digest). Both are presumably part of the format shared
 * with the peer, so they are left unchanged; strengthening them would need a new version.
 *
 * @author allen
 * @date 2020-10-10
 * @since 1.0
 */
public class AesEncryptor implements Encryptor {

    private static final Logger LOGGER = LoggerFactory.getLogger(AesEncryptor.class);

    private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;
    private static final String KEY_ALGORITHM = "AES";
    private static final String DEFAULT_CIPHER_ALGORITHM = "AES/ECB/PKCS5Padding";

    private static final String SIGN = "sign";

    /** Lowercase hex alphabet used when rendering cipher bytes. */
    private static final char[] HEX_DIGITS = "0123456789abcdef".toCharArray();

    @Override
    public int getVersion() {
        return 2;
    }

    /**
     * Resolves the key for a request: the {@code sign} header, falling back to the {@code sign}
     * request parameter when the header is absent, reversed character by character.
     *
     * @param request the current HTTP request
     * @return the reversed sign, or {@code null} when neither header nor parameter is present
     */
    @Override
    public String getKey(HttpServletRequest request) {
        String sign = request.getHeader(SIGN);
        if (sign == null) {
            sign = request.getParameter(SIGN);
        }
        // Use the already-imported commons-lang3 StringUtils (null-safe) instead of the legacy
        // commons-lang 2 class.
        return StringUtils.reverse(sign);
    }

    /**
     * Encrypts {@code context.getContent()} with {@code context.getKey()}.
     *
     * @param context supplies the plaintext and the key
     * @return the encrypted text, or {@code null} if the input is blank or encryption fails
     *     (the failure is logged, not rethrown)
     */
    @Override
    public String encrypt(EncryptContext context) {
        String content = context.getContent();
        String key = context.getKey();
        try {
            return encrypt(content, key);
        } catch (Exception e) {
            // NOTE(review): this logs the key and the plaintext; consider whether that is acceptable.
            LOGGER.warn("encrypt: \n{}\n{}", key, content, e);
            return null;
        }
    }

    /**
     * Decrypts {@code context.getContent()} with {@code context.getKey()}.
     *
     * @param context supplies the encrypted text and the key
     * @return the plaintext, or {@code null} if the encrypted text is null/blank
     * @throws Exception any decryption failure, logged and then rethrown
     */
    @Override
    public String decrypt(EncryptContext context) throws Exception {
        String content = context.getContent();
        String key = context.getKey();
        try {
            return decrypt(content, key);
        } catch (Exception e) {
            // NOTE(review): this logs the key and the ciphertext; consider whether that is acceptable.
            LOGGER.warn("decrypt:\n{}\n{}", key, content, e);
            throw e;
        }
    }

    /**
     * Encrypts UTF-8 text into {@code Base64(lowercaseHex(AES(content)))}.
     *
     * @param content plaintext to encrypt
     * @param key key string; the AES key is derived from its MD5 digest
     * @return the encoded ciphertext, or {@code null} if {@code content} or {@code key} is blank
     * @throws Exception if the cipher cannot be created or initialised, or encryption fails
     */
    public static String encrypt(String content, String key) throws Exception {
        if (StringUtils.isBlank(content) || StringUtils.isBlank(key)) {
            return null;
        }
        Cipher cipher = Cipher.getInstance(DEFAULT_CIPHER_ALGORITHM);
        cipher.init(Cipher.ENCRYPT_MODE, getSecretKeySpec(key));
        byte[] cipherBytes = cipher.doFinal(content.getBytes(DEFAULT_CHARSET));
        // parseByte2HexStr already produces lowercase hex, so no case conversion is needed.
        String hexText = parseByte2HexStr(cipherBytes);
        byte[] encoded = Base64.encodeBase64(hexText.getBytes(DEFAULT_CHARSET));
        return new String(encoded, DEFAULT_CHARSET);
    }

    /**
     * Reverses {@link #encrypt(String, String)}.
     *
     * @param encryptContent {@code Base64(lowercaseHex(AES(plaintext)))}
     * @param key key string used for encryption
     * @return the UTF-8 plaintext, or {@code null} if {@code encryptContent} is {@code null} or
     *     decodes to no bytes (empty, whitespace, or no Base64 characters at all)
     * @throws IllegalArgumentException if the decoded payload is not valid hex
     * @throws Exception if the cipher cannot be created or initialised, or decryption fails
     */
    public static String decrypt(String encryptContent, String key) throws Exception {
        byte[] hexBytes = Base64.decodeBase64(encryptContent);
        if (hexBytes == null || hexBytes.length == 0) {
            // Nothing to decrypt; mirrors encrypt()'s blank-input contract instead of letting
            // Cipher.doFinal fail on a null buffer.
            return null;
        }
        // Decode with the same charset encrypt() used, not the platform default.
        String hexText = new String(hexBytes, DEFAULT_CHARSET);
        Cipher cipher = Cipher.getInstance(DEFAULT_CIPHER_ALGORITHM);
        cipher.init(Cipher.DECRYPT_MODE, getSecretKeySpec(key));
        byte[] plainBytes = cipher.doFinal(parseHexStr2Byte(hexText));
        return new String(plainBytes, DEFAULT_CHARSET);
    }

    /**
     * Derives the AES key: the UTF-8 bytes of the first 16 characters of the MD5 hex digest of
     * {@code key}, giving a 16-byte (128-bit) key.
     */
    private static SecretKeySpec getSecretKeySpec(String key) {
        String md5 = DigestUtils.md5Hex(key).substring(0, 16);
        byte[] bytes = md5.getBytes(DEFAULT_CHARSET);
        return new SecretKeySpec(bytes, KEY_ALGORITHM);
    }

    /**
     * Converts a byte array to a lowercase hex string, two characters per byte.
     */
    private static String parseByte2HexStr(byte[] buf) {
        StringBuilder sb = new StringBuilder(buf.length * 2);
        for (byte b : buf) {
            sb.append(HEX_DIGITS[(b >> 4) & 0x0F]);
            sb.append(HEX_DIGITS[b & 0x0F]);
        }
        return sb.toString();
    }

    /**
     * Converts a hex string (either case) to a byte array.
     *
     * @throws IllegalArgumentException if the length is odd or a non-hex character is present
     */
    private static byte[] parseHexStr2Byte(String hexStr) {
        int length = hexStr.length();
        if ((length & 1) != 0) {
            throw new IllegalArgumentException("Hex string has odd length: " + length);
        }
        byte[] result = new byte[length / 2];
        for (int i = 0; i < result.length; i++) {
            int high = Character.digit(hexStr.charAt(2 * i), 16);
            int low = Character.digit(hexStr.charAt(2 * i + 1), 16);
            if (high < 0 || low < 0) {
                throw new IllegalArgumentException("Invalid hex character near index " + (2 * i));
            }
            result[i] = (byte) ((high << 4) | low);
        }
        return result;
    }
}
