package com.bxm.warcar.web.util;

import com.bxm.warcar.utils.JsonHelper;
import com.google.common.collect.Sets;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Aspect that handles methods annotated with {@link Encrypted}: it decrypts the incoming
 * {@link EncryptedMessage} argument before the method runs and, optionally, encrypts the method's result.
 *
 * @author allen
 * @date 2020-10-09
 * @since 1.0
 */
@Aspect
public class EncryptedControllerMethodAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptedControllerMethodAspect.class);

    /**
     * Wrapper/value types once treated as "primitive" models; only referenced by the deprecated
     * {@link #isPrimitive(Class)}.
     */
    private final Set<Class<?>> primitives = Sets.newHashSet(String.class, Short.class, Byte.class, Integer.class,
            Float.class, Long.class, Boolean.class, Double.class, BigDecimal.class);

    private final EncryptorFactory encryptorFactory;

    /**
     * @param encryptorFactory resolves the {@link Encryptor} for the cipher version carried by a request
     */
    public EncryptedControllerMethodAspect(EncryptorFactory encryptorFactory) {
        this.encryptorFactory = encryptorFactory;
    }

    /**
     * Matches every method annotated with {@link Encrypted}.
     */
    @Pointcut("@annotation(com.bxm.warcar.web.util.Encrypted)")
    public void pointcut() {
    }

    /**
     * <p>Around advice for methods annotated with {@link Encrypted}.</p>
     * <p>Before invocation the call is checked; the method is simply invoked unchanged if any of these hold:</p>
     * 1. it is not a servlet call, i.e. {@code HttpServletRequest} or {@code HttpServletResponse} is unavailable<br>
     * 2. the method has no arguments, or none of its arguments is an {@link EncryptedMessage}<br>
     * 3. the {@link Encrypted} annotation cannot be resolved on the target method<br>
     * 4. {@link EncryptedMessage#getMessage()} is blank<br>
     * 5. {@link EncryptedMessage#getCipher()} is {@code null}, or no {@link Encryptor} is registered for it<br>
     * <p>Otherwise {@link EncryptedMessage#getMessage()} is decrypted with the matching {@link Encryptor}, parsed
     * according to {@link Encrypted#plaintextFormat()} into {@link Encrypted#model()}, and assigned through
     * {@link EncryptedMessage#setObject(Object)}; the method body only needs to use that object.
     * If decryption fails and {@link Encrypted#interruptForException()} is set, the response status is set to
     * 400 and {@code null} is returned without invoking the method; otherwise the failure is logged and the
     * method is invoked with a {@code null} object.</p>
     *
     * <p>After invocation, a non-null result is encrypted when {@link Encrypted#encryptResponseEntity()} is set,
     * using {@link Encryptor#encrypt(EncryptContext)}. For a {@link ResponseEntity} only the body is encrypted,
     * and only when its status is listed in {@link Encrypted#encryptResponseHttpStatus()}; status and headers
     * are preserved.</p>
     *
     * @param point join point
     * @return the (possibly encrypted) result of the intercepted method
     * @throws Throwable anything thrown by the intercepted method or by the encryption pipeline
     */
    @Around("pointcut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        Object[] args = point.getArgs();
        if (args.length == 0) {
            return point.proceed();
        }

        HttpServletRequest request = WebContextUtils.getRequest();
        HttpServletResponse response = WebContextUtils.getResponse();
        if (Objects.isNull(request) || Objects.isNull(response)) {
            return point.proceed();
        }

        Method method = getMethod(point);
        Encrypted encrypted = Objects.isNull(method) ? null : method.getAnnotation(Encrypted.class);
        if (Objects.isNull(encrypted)) {
            if (LOGGER.isWarnEnabled()) {
                LOGGER.warn("Method [{}] Cannot found Encrypted annotation.", method);
            }
            return point.proceed();
        }

        EncryptedMessage<?> message = findEncryptedMessage(args);
        if (Objects.isNull(message)) {
            if (LOGGER.isWarnEnabled()) {
                LOGGER.warn("Method [{}] Cannot found EncryptedMessage object in arguments list.", method);
            }
            return point.proceed();
        }

        String msg = message.getMessage();
        if (StringUtils.isBlank(msg)) {
            if (LOGGER.isWarnEnabled()) {
                LOGGER.warn("Method [{}] message is blank", method);
            }
            return point.proceed();
        }

        Integer cipherVersion = message.getCipher();
        if (null == cipherVersion) {
            return point.proceed();
        }
        Encryptor encryptor = encryptorFactory.get(cipherVersion);
        if (Objects.isNull(encryptor)) {
            if (LOGGER.isWarnEnabled()) {
                LOGGER.warn("Method [{}]  unsupported cipher version: {}", method, cipherVersion);
            }
            return point.proceed();
        }

        String key = encryptor.getKey(request);

        // Decrypt the request payload.
        String body = null;
        try {
            body = encryptor.decrypt(new EncryptContext().setContent(msg).setKey(key));
        } catch (Exception e) {
            // Never log the key or the payload, only where and with which cipher it failed.
            if (LOGGER.isWarnEnabled()) {
                LOGGER.warn("Method [{}] decrypt message failed, cipher version: {}", method, cipherVersion, e);
            }
            if (encrypted.interruptForException()) {
                response.setStatus(HttpStatus.BAD_REQUEST.value());
                return null;
            }
        }
        Object object = serializeBody(encrypted.plaintextFormat(), body, encrypted.model());

        // Reflection is used because the message is only known as EncryptedMessage<?>. If the annotation's
        // model type disagrees with the message's generic type, a ClassCastException may surface later.
        MethodUtils.invokeMethod(message, "setObject", object);

        // Encrypt the response.
        Object proceed = point.proceed();
        if (Objects.isNull(proceed) || !encrypted.encryptResponseEntity()) {
            return proceed;
        }
        if (proceed instanceof ResponseEntity) {
            ResponseEntity<?> entity = (ResponseEntity<?>) proceed;
            if (!isEncryptableStatus(encrypted, entity)) {
                return proceed;
            }
            return ResponseEntity.status(entity.getStatusCode()).headers(entity.getHeaders())
                    .body(processResponse(encrypted, encryptor, key, proceed));
        }
        return processResponse(encrypted, encryptor, key, proceed);
    }

    /**
     * Returns the first argument that is an {@link EncryptedMessage}, or {@code null} if there is none.
     */
    private static EncryptedMessage<?> findEncryptedMessage(Object[] args) {
        for (Object candidate : args) {
            if (candidate instanceof EncryptedMessage) {
                return (EncryptedMessage<?>) candidate;
            }
        }
        return null;
    }

    /**
     * Tells whether the entity's status code is one of {@link Encrypted#encryptResponseHttpStatus()}.
     */
    private static boolean isEncryptableStatus(Encrypted encrypted, ResponseEntity<?> entity) {
        for (HttpStatus httpStatus : encrypted.encryptResponseHttpStatus()) {
            if (httpStatus == entity.getStatusCode()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Converts the decrypted {@code body} string into an object.
     *
     * @param textFormat plaintext format; {@link TextFormat#PARAMS} uses {@link #convertModel(String, Class)},
     *                   anything else is treated as JSON
     * @param body       decrypted text, may be {@code null}
     * @param model      target type
     * @return the converted instance, or {@code null} when {@code body} is blank
     */
    private Object serializeBody(TextFormat textFormat, String body, Class<?> model) {
        if (StringUtils.isNotBlank(body)) {
            if (TextFormat.PARAMS == textFormat) {
                return convertModel(body, model);
            } else {
                return JsonHelper.convert(body, model);
            }
        }
        return null;
    }

    /**
     * Instantiates a {@link Serialization} through its no-arg constructor. Unlike the deprecated
     * {@code Class#newInstance()}, constructor exceptions are reported wrapped rather than sneaky-thrown.
     */
    private Serialization newInstance(Class<? extends Serialization> clazz) throws ReflectiveOperationException {
        return clazz.getDeclaredConstructor().newInstance();
    }

    /**
     * Serializes the result (the body for a {@link ResponseEntity}) with {@link Encrypted#serialization()} and
     * encrypts it with the given encryptor and key.
     *
     * @return the encrypted content, or {@code null} when there is nothing to encrypt
     */
    private Object processResponse(Encrypted encrypted, Encryptor encryptor, String key, Object proceed)
            throws ReflectiveOperationException {
        Object res = proceed instanceof ResponseEntity ? ((ResponseEntity<?>) proceed).getBody() : proceed;
        if (Objects.isNull(res)) {
            return null;
        }
        Serialization instance = newInstance(encrypted.serialization());
        return encryptor.encrypt(new EncryptContext().setContent(instance.serialize(res)).setKey(key));
    }

    /**
     * Resolves the executed method on the target class so annotations declared there are visible. Falls back
     * to the method from the matched signature when there is no target or the method is not public.
     */
    private Method getMethod(ProceedingJoinPoint point) {
        MethodSignature methodSignature = (MethodSignature) point.getSignature();
        Object target = point.getTarget();
        if (Objects.isNull(target)) {
            return methodSignature.getMethod();
        }
        try {
            return target.getClass().getMethod(methodSignature.getName(), methodSignature.getParameterTypes());
        } catch (NoSuchMethodException e) {
            if (LOGGER.isErrorEnabled()) {
                LOGGER.error("getMethod:", e);
            }
            // Class#getMethod only sees public methods; use the signature's method instead of giving up.
            return methodSignature.getMethod();
        }
    }

    @Deprecated
    private boolean isPrimitive(Class<?> clazz) {
        return primitives.stream().anyMatch(clazz::isAssignableFrom);
    }

    /**
     * Converts a {@code k1=v1&k2=v2} parameter string into an instance of {@code model}.
     * Dotted keys such as {@code a.b=1} become nested objects ({@code {"a":{"b":"1"}}}). Pairs without
     * {@code '='} or with an empty value are skipped. Values are not URL-decoded. When a key is assigned more
     * than once, including a scalar that is later used as a nested parent, the last assignment wins.
     *
     * @param body  parameter string
     * @param model target type
     * @return the converted instance
     */
    public Object convertModel(String body, Class<?> model) {
        Map<String, Object> properties = new HashMap<>();
        for (String param : body.split("&")) {
            // Split on the first '=' only, so values containing '=' (e.g. Base64 padding) stay intact.
            String[] kv = param.split("=", 2);
            if (kv.length != 2 || kv[1].isEmpty()) {
                continue;
            }
            String name = kv[0];
            String value = kv[1];
            if (name.contains(".")) {
                putNested(properties, name.split("\\."), value);
            } else {
                properties.put(name, value);
            }
        }
        return JsonHelper.convert(JsonHelper.convert(properties), model);
    }

    /**
     * Stores {@code value} under the nested {@code path}, creating intermediate maps as needed.
     */
    @SuppressWarnings("unchecked")
    private static void putNested(Map<String, Object> root, String[] path, String value) {
        if (path.length == 0) {
            // A key made only of dots (e.g. "..") names nothing.
            return;
        }
        Map<String, Object> parent = root;
        for (int i = 0; i < path.length - 1; i++) {
            Object child = parent.get(path[i]);
            if (!(child instanceof Map)) {
                // Missing, or a scalar from an earlier "a=..." pair: replace it (last assignment wins).
                child = new HashMap<String, Object>();
                parent.put(path[i], child);
            }
            parent = (Map<String, Object>) child;
        }
        parent.put(path[path.length - 1], value);
    }
}
