package com.bxm.adx.common.buy.dispatcher.filter;

import com.bxm.adx.common.adapter.AdxContextFactory;
import com.bxm.adx.common.buy.dispatcher.Dispatcher;
import com.bxm.adx.common.buy.dispatcher.DispatcherContext;
import com.bxm.adx.facade.constant.enums.AdxErrEnum;
import com.bxm.adx.facade.exception.AdxException;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.RandomStringUtils;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Collects every {@link DispatcherFilter} bean once the application is ready and runs them,
 * in a fixed order, against the dispatchers held by a {@link DispatcherContext}.
 *
 * <p>Execution order is: filters without a {@link DispatcherFilterCondition} annotation, then
 * annotated filters with {@code probeOn() == false}, then annotated filters with
 * {@code probeOn() == true}; ties inside a group are resolved by {@link DispatcherFilter#getOrder()}.
 * Dispatchers removed by probe-enabled filters are remembered and may be re-admitted by the
 * probe step at the end of {@link #filter(DispatcherContext)}.
 *
 * @author fgf
 * @date 2023/3/13
 **/
@Slf4j
@Configuration
public class DispatcherFilterFactory implements ApplicationListener<ApplicationReadyEvent> {
    /** Registered filters in execution order; filled by {@link #onApplicationEvent}. */
    private final Collection<DispatcherFilter> dispatcherFilters = Lists.newArrayList();

    /**
     * Looks up all {@link DispatcherFilter} beans in the ready application context, sorts them
     * and appends them to the internal filter chain.
     *
     * @param event the ready event carrying the application context to scan
     */
    @Override
    public void onApplicationEvent(ApplicationReadyEvent event) {
        ConfigurableApplicationContext context = event.getApplicationContext();
        Map<String, DispatcherFilter> filterMap = context.getBeansOfType(DispatcherFilter.class);
        if (CollectionUtils.isEmpty(filterMap)) {
            return;
        }

        // Sort by stage first (plain filters before probe-enabled ones), then by order.
        // Keeping all probe-enabled filters at the tail makes the probe simple to implement:
        // a dispatcher only has to pass the probe once to bypass every probe-enabled filter.
        Comparator<DispatcherFilter> byStage = Comparator.comparingInt(this::stageOf);
        dispatcherFilters.addAll(
                filterMap.values().stream()
                        .sorted(byStage.thenComparing(DispatcherFilter::getOrder))
                        .collect(Collectors.toList()));

        if (log.isInfoEnabled()) {
            for (DispatcherFilter registered : dispatcherFilters) {
                log.info("Registered DispatcherFilter: {} - {}", registered.getOrder(), registered);
            }
        }
    }

    /**
     * Runs the filter chain over {@code context.getValues()}, removing every dispatcher a filter
     * puts into its trash set, then gives dispatchers removed by probe-enabled filters a chance
     * to be re-admitted via {@link #probe}.
     *
     * @param context the dispatch context whose candidate dispatchers are filtered in place
     * @throws AdxException with {@link AdxErrEnum#DISPATCHER_ERR} when no dispatcher is left
     *                      after filtering and probing
     */
    public void filter(DispatcherContext<Dispatcher> context) {
        Set<Dispatcher> trash = Sets.newHashSet();
        Set<Dispatcher> probeTrash = Sets.newHashSet();
        // A plain loop (instead of forEach + lambda) is required here: inside a lambda a
        // "return" only ends the current element, so the chain could never be cut short.
        for (DispatcherFilter dispatcherFilter : dispatcherFilters) {
            String clazzSimpleName = ClassUtils.getUserClass(dispatcherFilter).getSimpleName();
            DispatcherFilterCondition condition = getAnnotation(dispatcherFilter);
            if (checkCondition(context, condition)) {
                if (log.isDebugEnabled()) {
                    log.debug("filter {} ignore", clazzSimpleName);
                }
                continue;
            }

            // Let the filter mark the dispatchers it rejects.
            dispatcherFilter.filter(context, trash);
            if (!CollectionUtils.isEmpty(trash)) {
                if (log.isDebugEnabled()) {
                    log.debug("filter {} trash {}", clazzSimpleName, trash.stream().map(Dispatcher::getId).collect(Collectors.toSet()));
                }

                // Remember what probe-enabled filters rejected; the probe may bring it back.
                if (Objects.nonNull(condition) && condition.probeOn()) {
                    probeTrash.addAll(trash);
                }
                context.getValues().removeIf(trash::contains);
                // The trash set is shared by all filters, so reset it for the next one.
                trash.clear();
            }

            // Nothing left to filter: stop running the remaining filters.
            if (CollectionUtils.isEmpty(context.getValues())) {
                break;
            }
        }

        // Probe: possibly re-admit dispatchers rejected by probe-enabled filters.
        probe(context, probeTrash);

        if (CollectionUtils.isEmpty(context.getValues())) {
            throw new AdxException(AdxErrEnum.DISPATCHER_ERR);
        }
    }

    /**
     * Decides whether a filter must be skipped for the current request based on the docking
     * method type of the position.
     *
     * @param context   the dispatch context providing the position's docking method type
     * @param condition the filter's condition annotation, or {@code null} if it has none
     * @return {@code true} if the filter should be SKIPPED, i.e. the annotation restricts the
     *         filter to specific docking method types and the position's type is not among
     *         them; {@code false} (filter runs) when there is no annotation, no docking method
     *         type, no restriction, or the type matches
     */
    private boolean checkCondition(DispatcherContext<Dispatcher> context, DispatcherFilterCondition condition) {
        Integer dmType = context.getPosition().getDockingMethodType();
        if (Objects.isNull(condition) || Objects.isNull(dmType)) {
            return false;
        }
        int[] dmTypes = condition.onDmType();
        // The position's docking method type decides whether this filter takes effect.
        return ArrayUtils.isNotEmpty(dmTypes) && !ArrayUtils.contains(dmTypes, dmType);
    }

    /**
     * Probe step: re-adds to the context those dispatchers that were rejected by probe-enabled
     * filters, have a non-null probe value, and whose probe value (shifted two decimal places
     * to the right and truncated to an int) is greater than the current user's bucket.
     *
     * @param context    the dispatch context that receives the re-admitted dispatchers
     * @param probeTrash dispatchers rejected by probe-enabled filters
     */
    private void probe(DispatcherContext<Dispatcher> context, Set<Dispatcher> probeTrash) {
        if (CollectionUtils.isEmpty(probeTrash)) {
            return;
        }

        int ratio = ratio();
        Set<Dispatcher> probeDispatcher = probeTrash.stream()
                .filter(dispatcher -> Objects.nonNull(dispatcher.getProbe()))
                .filter(dispatcher -> {
                    if (log.isDebugEnabled()) {
                        log.debug("Dispatcher {} probe {} user-probe {}", dispatcher.getId(), dispatcher.getProbe(), ratio);
                    }
                    return dispatcher.getProbe().movePointRight(2).intValue() > ratio;
                })
                .collect(Collectors.toSet());
        context.getValues().addAll(probeDispatcher);
    }

    /**
     * Computes the user bucket used by the probe.
     *
     * @return a value in {@code [0, 99]} derived from the hash of the current uid; when the uid
     *         is empty a random 8-character alphanumeric string is hashed instead
     */
    private int ratio() {
        String uid = AdxContextFactory.get().getUid();
        if (StringUtils.isEmpty(uid)) {
            uid = RandomStringUtils.randomAlphanumeric(8);
        }
        // The remainder lies in (-100, 100), so Math.abs cannot overflow here.
        return Math.abs(uid.hashCode() % 100);
    }

    /**
     * Maps a filter to its execution stage for sorting.
     *
     * @param filter the filter to classify
     * @return {@code 0} without a condition annotation, {@code 1} with an annotation whose
     *         {@code probeOn()} is false, {@code 2} with an annotation whose {@code probeOn()}
     *         is true
     */
    private int stageOf(DispatcherFilter filter) {
        DispatcherFilterCondition condition = getAnnotation(filter);
        if (Objects.isNull(condition)) {
            return 0;
        }
        return condition.probeOn() ? 2 : 1;
    }

    /**
     * Finds the {@link DispatcherFilterCondition} annotation on the filter's user class
     * (resolved through {@link ClassUtils#getUserClass(Object)}).
     *
     * @param filter the filter bean to inspect
     * @return the annotation, or {@code null} if none is found
     */
    private DispatcherFilterCondition getAnnotation(DispatcherFilter filter) {
        return AnnotationUtils.findAnnotation(ClassUtils.getUserClass(filter), DispatcherFilterCondition.class);
    }
}
