package com.bxm.adx.common.buy.dispatcher.filter;

import cn.hutool.core.collection.CollUtil;
import com.bxm.adx.common.adapter.AdxContextFactory;
import com.bxm.adx.common.buy.dispatcher.Dispatcher;
import com.bxm.adx.common.buy.dispatcher.DispatcherContext;
import com.bxm.adx.common.buy.dispatcher.DispatcherFlowControl;
import com.bxm.adx.common.caching.Id;
import com.bxm.adx.common.sell.BidRequest;
import com.bxm.warcar.integration.pair.Pair;
import com.bxm.warcar.utils.JsonHelper;
import com.google.common.hash.Hashing;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.RandomStringUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.core.Ordered;

import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Flow-control ("traffic cutting") filter.
 * <p>
 * Must be executed last among the dispatcher filters (see {@link #getOrder()}).
 * Each request is assigned to one A/B bucket derived from the user id. Dispatchers whose
 * chop-quantity switch is on and whose request falls into the first bucket are handed to
 * {@link DispatcherFlowControl} to decide whether they are cut. The remaining bucket is the
 * blank control group and is never cut here.
 *
 * @author fgf
 * @date 2023/5/4
 **/
@Slf4j
public class FlowControlFilter implements DispatcherFilter<Dispatcher> {

    /** Pair key holding the percentage (0-100) of traffic assigned to the flow-control bucket. */
    private static final String KEY = "adx.flow.control.weight";
    /** Percentage used when the configured value is missing or cannot be read. */
    private static final int DEFAULT_FLOW_CONTROL_PERCENT = 90;

    private final DispatcherFlowControl dispatcherFlowControl;
    private final Config config;
    private final Pair pair;

    public FlowControlFilter(DispatcherFlowControl dispatcherFlowControl, Pair pair) {
        this.dispatcherFlowControl = dispatcherFlowControl;
        this.config = new Config();
        this.pair = pair;
        loadFlowControlWeights();
    }

    /**
     * Reads the flow-control percentage from {@link Pair} and derives the two bucket weights
     * {@code [percent / 100, 1 - percent / 100]}.
     * <p>
     * Best-effort: any failure (lookup error, non-numeric value) is logged and the default of
     * {@value #DEFAULT_FLOW_CONTROL_PERCENT}% is used. Values outside 0..100 are clamped so no
     * bucket gets a negative weight. Weights are read only once, at construction time.
     */
    private void loadFlowControlWeights() {
        int flowControl = DEFAULT_FLOW_CONTROL_PERCENT;
        try {
            Set<String> flowControlWeights = pair.get(KEY).ofHashSet();
            if (CollectionUtils.isNotEmpty(flowControlWeights)) {
                // NOTE(review): the value is stored as a set; if it ever holds more than one
                // member, which one is picked here is presumably arbitrary.
                flowControl = Integer.parseInt(StringUtils.trim(CollUtil.getFirst(flowControlWeights)));
            }
        } catch (Exception e) {
            log.error("load flow control weights error", e);
        }
        if (flowControl < 0 || flowControl > 100) {
            log.warn("flow control weight {} out of range [0, 100], clamped", flowControl);
            flowControl = Math.max(0, Math.min(100, flowControl));
        }
        // percent -> fraction, exact (e.g. 90 -> 0.90)
        BigDecimal flowControlPercentage = BigDecimal.valueOf(flowControl).movePointLeft(2);
        BigDecimal remainingPercentage = BigDecimal.ONE.subtract(flowControlPercentage);
        config.setBucketWeights(new double[]{
                flowControlPercentage.doubleValue(), remainingPercentage.doubleValue()});
    }

    /**
     * Moves every dispatcher that flow control decides to cut into {@code trash}.
     *
     * @param context dispatchers and request being processed
     * @param trash   collector of dispatchers to drop; cut dispatchers are added to it
     */
    @Override
    public void filter(DispatcherContext<Dispatcher> context, Set<Id> trash) {
        Collection<Dispatcher> dispatchers = context.getValues();
        if (CollectionUtils.isEmpty(dispatchers)) {
            return;
        }
        BidRequest request = context.getRequest();

        // An explicit loop instead of a stream predicate: deciding a dispatcher also records the
        // bucket on the context and on the dispatcher, and side effects don't belong in filter().
        Set<Dispatcher> removes = new HashSet<>();
        for (Dispatcher dispatcher : dispatchers) {
            if (shouldCut(context, dispatcher, request)) {
                removes.add(dispatcher);
            }
        }
        if (CollectionUtils.isNotEmpty(removes)) {
            trash.addAll(removes);
        }
    }

    /**
     * Decides whether a single dispatcher is cut.
     * <p>
     * Dispatchers with the chop-quantity switch off are never cut and are left untouched.
     * Otherwise the request's bucket is stamped onto the dispatcher. Only the first bucket is
     * subject to {@link DispatcherFlowControl#flowControl}; the other bucket is the blank
     * control group.
     *
     * @return {@code true} if the dispatcher should be removed
     */
    private boolean shouldCut(DispatcherContext<Dispatcher> context, Dispatcher dispatcher, BidRequest request) {
        Byte flowSwitch = dispatcher.getChopQuantitySwitch();
        if (Objects.isNull(flowSwitch) || Dispatcher.FLOW_OPENED_YES != flowSwitch) {
            return false;
        }
        String bucket = resolveBucket(context);
        dispatcher.setAlgoFlowControlBucket(bucket);
        return config.getBucketIds()[0].equals(bucket)
                && dispatcherFlowControl.flowControl(dispatcher, request);
    }

    /**
     * Returns the bucket cached on the context. If none is cached yet, assigns one from the
     * user id and stores it, so all dispatchers evaluated against this context share one bucket.
     */
    private String resolveBucket(DispatcherContext<Dispatcher> context) {
        String bucket = context.getAlgoFlowControlBucket();
        if (StringUtils.isEmpty(bucket)) {
            bucket = config.getBucketIds()[bucket(initUid())];
            context.setAlgoFlowControlBucket(bucket);
        }
        return bucket;
    }

    /**
     * Runs last so flow control only sees dispatchers that survived every other filter.
     */
    @Override
    public int getOrder() {
        return Ordered.LOWEST_PRECEDENCE;
    }

    /**
     * Returns the uid from the current {@link AdxContextFactory} context. Falls back to a random
     * 8-character alphanumeric string when it is blank; such requests are then bucketed randomly
     * rather than stably.
     *
     * @return a non-blank uid
     */
    private String initUid() {
        String uid = AdxContextFactory.get().getUid();
        if (StringUtils.isBlank(uid)) {
            uid = RandomStringUtils.randomAlphanumeric(8);
        }
        return uid;
    }

    /**
     * Maps a uid deterministically to a bucket index using the configured bucket weights.
     * <p>
     * The uid's murmur3 hash is scaled into {@code [0, totalWeight]} and the first bucket whose
     * cumulative weight exceeds that value is chosen.
     *
     * @param uid user identifier to hash
     * @return bucket index in {@code [0, numBuckets - 1]}
     */
    private int bucket(String uid) {
        double[] bucketWeights = config.getBucketWeights();
        // Never index past the weights array, even if bucketNum is misconfigured.
        int numBuckets = Math.min(config.getBucketNum(), bucketWeights.length);
        // NOTE(review): Guava 31+ deprecates murmur3_32() in favour of murmur3_32_fixed(), which
        // differs only for non-BMP input; switching would re-bucket those uids.
        int hash = Hashing.murmur3_32().hashString(uid, StandardCharsets.UTF_8).asInt();
        // Math.abs(Integer.MIN_VALUE) stays negative, which would put that uid into bucket 0 even
        // when bucket 0 has zero weight. Map it to 0; every other hash keeps |hash| so existing
        // bucket assignments are unchanged.
        hash = hash == Integer.MIN_VALUE ? 0 : Math.abs(hash);

        double totalWeight = 0.0;
        for (double weight : bucketWeights) {
            totalWeight += weight;
        }

        // Scale the hash into [0, totalWeight].
        double randomValue = (double) hash / Integer.MAX_VALUE * totalWeight;

        double cumulativeWeight = 0.0;
        for (int i = 0; i < numBuckets; i++) {
            cumulativeWeight += bucketWeights[i];
            if (randomValue < cumulativeWeight) {
                return i;
            }
        }

        // randomValue == totalWeight (hash == Integer.MAX_VALUE) or rounding: use the last bucket.
        return numBuckets - 1;
    }

    /**
     * A/B experiment configuration.
     */
    @Data
    static class Config {
        /**
         * Experiment name.
         */
        private String name = "flow-control02";
        /**
         * Number of buckets.
         */
        private int bucketNum = 2;

        /**
         * Share of traffic per bucket.
         * e.g. [0.2, 0.3, 0.2, 0.3]
         */
        private double[] bucketWeights;

        /**
         * Bucket ids; index 0 is the flow-control bucket, the rest are control groups.
         */
        private String[] bucketIds = new String[]{"a2", "b2"};
    }
}
