package com.bxm.newidea.component.filter;

import com.bxm.newidea.component.annotations.FilterBean;
import com.bxm.newidea.component.exception.ExcutorException;
import com.bxm.newidea.component.tools.SpringContextHolder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.core.OrderComparator;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Executes the filters registered for a business group. Every filter of the group is run
 * (unless excluded via {@link #choice(Class)} / {@link #skip(Class)}); a failing filter aborts the
 * chain with an {@link ExcutorException}.
 * <p>
 * Filters are collected once at application startup in {@link #run(ApplicationArguments)} from all
 * {@link IFilter} beans carrying a {@link FilterBean} annotation, grouped by its {@code group()}
 * value and sorted with Spring's {@link OrderComparator}.
 * <p>
 * {@link #choice(Class)} and {@link #skip(Class)} record selections for the current thread only.
 * They are consumed (cleared) by the next {@code doFilter}/{@code parallelDoFilter} call on that
 * thread, whether or not the requested group has any filters.
 *
 * @author liujia
 * @date 1/13/21 6:23 PM
 **/
@Component
@Slf4j
public class FilterChainExecutor implements ApplicationRunner {

    /** A whole group taking longer than this (in ms) is logged at INFO level as slow. */
    private static final long SLOW_GROUP_THRESHOLD_MS = 100L;

    /** Filters per business group, populated once in {@link #run(ApplicationArguments)}. */
    private final Map<String, List<IFilter>> groupFilterMap = new HashMap<>();

    /** Per-thread whitelist set via {@link #choice(Class)}; cleared by the next execution. */
    private final ThreadLocal<List<Class<? extends IFilter>>> choiceFilterThreadLocal = new ThreadLocal<>();

    /** Per-thread blacklist set via {@link #skip(Class)}; cleared by the next execution. */
    private final ThreadLocal<List<Class<? extends IFilter>>> skipFilterThreadLocal = new ThreadLocal<>();

    /**
     * Restricts the next execution on the current thread to the selected filters only; every
     * filter that was not selected is skipped. Call repeatedly to select several filters.
     *
     * @param filterClass filter that should be executed
     * @return this executor, for call chaining
     */
    public FilterChainExecutor choice(Class<? extends IFilter> filterClass) {
        return setThreadLocal(filterClass, choiceFilterThreadLocal);
    }

    /**
     * Excludes a filter from the next execution on the current thread. Call repeatedly to skip
     * several filters. A skipped filter is not run even if it was also selected via
     * {@link #choice(Class)}.
     *
     * @param filterClass filter that should be skipped
     * @return this executor, for call chaining
     */
    public FilterChainExecutor skip(Class<? extends IFilter> filterClass) {
        return setThreadLocal(filterClass, skipFilterThreadLocal);
    }

    /**
     * Appends {@code filterClass} to the list held by {@code threadLocal} for the current thread,
     * creating and storing the list on first use.
     */
    private FilterChainExecutor setThreadLocal(Class<? extends IFilter> filterClass,
                                               ThreadLocal<List<Class<? extends IFilter>>> threadLocal) {
        List<Class<? extends IFilter>> filterList = threadLocal.get();
        if (null == filterList) {
            filterList = new ArrayList<>();
            threadLocal.set(filterList);
        }
        filterList.add(filterClass);
        return this;
    }

    /**
     * Executes the filters of the given business group sequentially, in their sorted order.
     *
     * @param group   business group
     * @param context context handed to each filter
     * @param <T>     context type
     * @throws ExcutorException if any filter throws
     */
    public <T> void doFilter(String group, T context) {
        execFilter(group, context, false);
    }

    /**
     * Executes the filters of the given business group, running the filters that support
     * parallelism concurrently (ignoring their order) and then the remaining ones sequentially.
     *
     * @param group   business group
     * @param context context handed to each filter; shared by concurrently running filters
     * @param <T>     context type
     * @throws ExcutorException if any filter throws
     */
    public <T> void parallelDoFilter(String group, T context) {
        execFilter(group, context, true);
    }

    private <T> void execFilter(String group, T context, boolean parallel) {
        List<IFilter> filters = groupFilterMap.get(group);

        if (null == filters) {
            // The selections were made for this call; clear them even though nothing runs,
            // otherwise they would leak into a later, unrelated call on a pooled thread.
            clearSelections();
            return;
        }

        List<Class<? extends IFilter>> choiceFilterList = choiceFilterThreadLocal.get();
        List<Class<? extends IFilter>> skipFilterList = skipFilterThreadLocal.get();

        // nanoTime is monotonic, unlike the wall clock, so elapsed times cannot go negative.
        long start = System.nanoTime();

        try {
            if (parallel) {
                // Parallel-stream tasks may run on ForkJoinPool worker threads rather than the
                // caller thread, which is why the selections are captured as locals above.
                filters.parallelStream()
                        .filter(IFilter::supprtParallel)
                        .forEach(filter -> internalFilter(filter, context, choiceFilterList, skipFilterList));

                // Then run, sequentially and in order, the filters that do not support parallelism.
                for (IFilter filter : filters) {
                    if (!filter.supprtParallel()) {
                        internalFilter(filter, context, choiceFilterList, skipFilterList);
                    }
                }
            } else {
                for (IFilter filter : filters) {
                    internalFilter(filter, context, choiceFilterList, skipFilterList);
                }
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            throw new ExcutorException("过滤器执行失败,策略分组：" + group, e);
        } finally {
            clearSelections();
        }

        long expense = (System.nanoTime() - start) / 1_000_000L;

        if (log.isDebugEnabled()) {
            log.debug("逻辑分组[{}]执行总耗时：{}",
                    group,
                    expense);
        }

        if (expense > SLOW_GROUP_THRESHOLD_MS) {
            log.info("逻辑分组[{}]执行总耗时：{}，超过100ms，需要重点关注", group, expense);
        }
    }

    /** Removes the current thread's choice/skip selections. */
    private void clearSelections() {
        choiceFilterThreadLocal.remove();
        skipFilterThreadLocal.remove();
    }

    /**
     * Runs a single filter unless it is excluded by the skip list, or by a non-empty choice list
     * that does not contain it.
     */
    @SuppressWarnings("unchecked")
    private <T> void internalFilter(IFilter filter, T context,
                                    List<Class<? extends IFilter>> choiceFilterList,
                                    List<Class<? extends IFilter>> skipFilterList) {
        // NOTE(review): matching uses the runtime class; if a filter bean is a CGLIB proxy its
        // runtime class is a generated subclass and will not equal the class passed to
        // choice()/skip() — confirm filters are never proxied.
        Class<?> filterClass = filter.getClass();

        if (null != skipFilterList && skipFilterList.contains(filterClass)) {
            return;
        }

        if (null != choiceFilterList && !choiceFilterList.isEmpty()
                && !choiceFilterList.contains(filterClass)) {
            return;
        }

        long start = System.nanoTime();

        filter.doFilter(context);

        if (log.isDebugEnabled()) {
            log.debug("过滤器[{}]执行耗费：{}ms", filterClass.getSimpleName(), (System.nanoTime() - start) / 1_000_000L);
        }
    }

    /**
     * Collects all {@link IFilter} beans at startup, groups them by {@link FilterBean} group and
     * sorts each group. Beans without the annotation are ignored with a warning.
     */
    @Override
    public void run(ApplicationArguments args) {
        for (IFilter filter : SpringContextHolder.getBeans(IFilter.class)) {
            FilterBean annotation = AnnotationUtils.findAnnotation(filter.getClass(), FilterBean.class);
            if (null == annotation) {
                log.warn("[{}]未提供FilterBean注解", filter.getClass().getSimpleName());
                continue;
            }

            groupFilterMap.computeIfAbsent(annotation.group(), key -> new ArrayList<>()).add(filter);
        }

        // NOTE(review): OrderComparator honours Ordered/PriorityOrdered only, not the @Order
        // annotation; if filters declare their order via @Order, AnnotationAwareOrderComparator
        // is needed — confirm how IFilter implementations express their order.
        for (List<IFilter> rules : groupFilterMap.values()) {
            OrderComparator.sort(rules);
        }
    }
}
