package com.bxm.lovelink.common.dal.service.impl;

import com.bxm.lovelink.common.dal.entity.FlowControl;
import com.bxm.lovelink.common.dal.service.IFlowControlService;
import com.bxm.lovelink.common.exception.BusinessException;
import com.google.common.collect.Lists;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.exceptions.JedisDataException;

import java.util.List;
import java.util.Objects;

/**
 * Flow control (rate limiting) implemented with a Redis Lua script.
 *
 * <p>Each {@link FlowControl} rule maps to one Redis counter key. The counter is incremented and
 * checked atomically inside a Lua script, so concurrent callers across processes share one limit.
 *
 * @author zhangdong
 * @date 2025/5/17
 */
@Service
public class FlowControlServiceImpl implements IFlowControlService, InitializingBean {

    /**
     * Lua script evaluated atomically on Redis.
     *
     * <ul>
     *   <li>{@code KEYS[1]} - counter key.</li>
     *   <li>{@code ARGV[1]} - TTL in seconds set when the counter is first created (the counting window).</li>
     *   <li>{@code ARGV[2]} - applied on the first call that exceeds the threshold: {@code >= 0} replaces
     *       the TTL with this value, {@code -1} removes the TTL (the key persists), any other value leaves
     *       the TTL untouched. Because later calls keep incrementing past the threshold, they are rejected
     *       until the key expires.</li>
     *   <li>{@code ARGV[3]} - number of calls allowed within the window.</li>
     * </ul>
     *
     * <p>Returns 1 if the current call is within the threshold, otherwise 0.
     * The text must stay stable: its SHA1 is what {@code EVALSHA} looks up.
     */
    private static final String RATE_LIMIT_SCRIPT = "local key = KEYS[1]\n" +
            "local countTime = tonumber(ARGV[1])\n" +
            "local expireTime = tonumber(ARGV[2])\n" +
            "local threshold = tonumber(ARGV[3])+1\n" +
            "\n" +
            "local value = redis.call('INCR', key)\n" +
            "\n" +
            "if value == 1 then\n" +
            "    redis.call('EXPIRE', key, countTime)\n" +
            "end\n" +
            "\n" +
            "if value == threshold then\n" +
            "    if expireTime >= 0 then\n" +
            "        redis.call('EXPIRE', key, expireTime)\n" +
            "    elseif expireTime == -1 then\n" +
            "        redis.call('PERSIST', key)\n" +
            "    end\n" +
            "end\n" +
            "\n" +
            "if value < threshold then\n" +
            "    return 1\n" +
            "else\n" +
            "    return 0\n" +
            "end";

    /** Prefix of the Redis error returned when a SHA is not in the server's script cache. */
    private static final String NOSCRIPT_ERROR_PREFIX = "NOSCRIPT";

    private final JedisPool jedisPool;

    /** SHA1 of {@link #RATE_LIMIT_SCRIPT}, obtained via SCRIPT LOAD in {@link #afterPropertiesSet()}. */
    private String scriptSha;

    public FlowControlServiceImpl(JedisPool jedisPool) {
        this.jedisPool = jedisPool;
    }

    /**
     * Counts one call against the rule's key and reports whether it is allowed.
     *
     * @param dto the rule to apply; must be non-null and pass {@link FlowControl#check()}
     * @return {@code true} if the call is within the configured threshold
     * @throws BusinessException if {@code dto} is null or fails validation
     */
    @Override
    public boolean permitPass(FlowControl dto) {
        if (Objects.isNull(dto) || !dto.check()) {
            throw new BusinessException("限流参数异常");
        }
        List<String> keys = Lists.newArrayList(dto.getRedisKey());
        List<String> args = Lists.newArrayList(
                String.valueOf(dto.getCountExpireTime()),
                String.valueOf(dto.getLimitExpireTime()),
                String.valueOf(dto.getCountThreshold()));
        try (Jedis resource = jedisPool.getResource()) {
            Long result;
            try {
                result = (Long) resource.evalsha(scriptSha, keys, args);
            } catch (JedisDataException e) {
                if (!isNoScriptError(e)) {
                    throw e;
                }
                // The server's script cache was lost (restart, failover, SCRIPT FLUSH).
                // EVAL runs the script and re-caches it, so later EVALSHA calls succeed again.
                result = (Long) resource.eval(RATE_LIMIT_SCRIPT, keys, args);
            }
            return Objects.equals(result, 1L);
        }
    }

    /**
     * Applies every rule in the list and reports whether all of them allowed the call.
     *
     * <p>Unlike {@link #permitPassTruncation(List)}, every rule is evaluated (and therefore counted)
     * even after an earlier rule has rejected the call.
     *
     * @param dtoList rules to apply
     * @return {@code true} only if every rule allowed the call
     */
    @Override
    public boolean permitPass(List<FlowControl> dtoList) {
        boolean allPassed = true;
        for (FlowControl dto : dtoList) {
            // Evaluate first: a short-circuit `allPassed && permitPass(dto)` would skip the remaining
            // rules after the first rejection, making this identical to permitPassTruncation.
            boolean passed = permitPass(dto);
            allPassed = allPassed && passed;
        }
        return allPassed;
    }

    /**
     * Applies rules in order and stops at the first one that rejects the call; later rules are
     * neither evaluated nor counted.
     *
     * @param dtoList rules to apply, in order
     * @return {@code true} only if every rule allowed the call
     */
    @Override
    public boolean permitPassTruncation(List<FlowControl> dtoList) {
        for (FlowControl dto : dtoList) {
            if (!permitPass(dto)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Loads the rate-limit script into Redis's script cache and remembers its SHA1 for EVALSHA.
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        try (Jedis resource = jedisPool.getResource()) {
            scriptSha = resource.scriptLoad(RATE_LIMIT_SCRIPT);
        }
    }

    /** Returns whether the Redis error is NOSCRIPT, i.e. the SHA is unknown to the server. */
    private static boolean isNoScriptError(JedisDataException e) {
        String message = e.getMessage();
        return message != null && message.startsWith(NOSCRIPT_ERROR_PREFIX);
    }
}
