package com.bxm.adx.common.algorithm;

import com.bxm.adx.common.AlgorithmProperties;
import com.bxm.adx.common.CacheKeys;
import com.bxm.adx.common.utils.DateUtils;
import com.bxm.warcar.cache.KeyGenerator;
import com.bxm.warcar.utils.NamedThreadFactory;
import com.bxm.warcar.xcache.Fetcher;
import com.bxm.warcar.xcache.TargetFactory;
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.boot.context.event.ApplicationPreparedEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.CollectionUtils;
import org.tensorflow.SavedModelBundle;

import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * @author fgf
 * @updated 2025-02-25
 */
@Slf4j
@Configuration
public class ModelGenerateBus implements ApplicationListener<ApplicationPreparedEvent> {
    /**
     * 特征 KV形式
     */
    private final Map<String, Map<String, Long>> oneHotMap = Maps.newConcurrentMap();

    /**
     * 加载的模型
     */
    private final Map<String, SavedModelBundle> savedModelBundleMap = Maps.newConcurrentMap();

    /**
     * 模型版本
     */
    private final Map<String, String> modelVersionMap = Maps.newConcurrentMap();

    private final ScheduledExecutorService schedulerClose =
            new ScheduledThreadPoolExecutor(1, new NamedThreadFactory("close-old-models"));
    private final ScheduledExecutorService schedulerLoad =
            new ScheduledThreadPoolExecutor(1, new NamedThreadFactory("load-models"));

    private final Fetcher fetcher;
    private final AlgorithmProperties algorithmProperties;

    public ModelGenerateBus(Fetcher fetcher, AlgorithmProperties algorithmProperties) {
        this.fetcher = fetcher;
        this.algorithmProperties = algorithmProperties;
    }

    @Override
    public void onApplicationEvent(ApplicationPreparedEvent event) {
        //项目启动时加载模型
        schedulerLoad.schedule(new ModelGenerateRum(), 120, TimeUnit.SECONDS);
        //每天定时加载模型
        schedulerLoad.scheduleAtFixedRate(new ModelGenerateRum(), DateUtils.getInitialDelayForDaily(9, 30),
                24 * 60 * 60 * 1000, TimeUnit.MILLISECONDS);
    }

    /**
     * 获取加载的特征名称
     *
     * @param modelId
     * @return
     */
    public Map<String, Long> getOneHotMap(String modelId) {
        if (oneHotMap.isEmpty()) {
            return null;
        }
        return oneHotMap.get(modelId);
    }

    public void setOneHotMap(String modelId, Map<String, Long> map) {
        oneHotMap.put(modelId, map);
    }


    /**
     * 获取加载的model
     *
     * @param modelId
     * @return
     */
    public SavedModelBundle getSavedModelBundleMap(String modelId) {
        if (savedModelBundleMap.isEmpty()) {
            return null;
        }
        return savedModelBundleMap.get(modelId);
    }

    public String getModelVersionMap(String modelId) {
        if (modelVersionMap.isEmpty()) {
            return null;
        }
        return modelVersionMap.get(modelId);
    }

    /**
     * 添加加载的model到map
     *
     * @param modelId
     * @param savedModelBundle
     * @return
     */
    public SavedModelBundle setSavedModelBundleMap(String modelId, SavedModelBundle savedModelBundle) {
        return savedModelBundleMap.put(modelId, savedModelBundle);
    }

    public String setModelVersionMap(String modelId, String version) {
        return modelVersionMap.put(modelId, version);
    }

    class ModelGenerateRum implements Runnable {
        @Override
        public void run() {
            for (AlgorithmModelEnum algorithmModelEnum : AlgorithmModelEnum.values()) {
                try {
                    String modelId = algorithmModelEnum.name();
                    AlgorithmProperties.ModelInfo modelInfo = getModelInfo(modelId);
                    if (modelInfo == null) {
                        log.warn("模型{}未配置", algorithmModelEnum.name());
                        continue;
                    }
                    String version = findVersion(modelId, modelInfo.getPath(), modelInfo.getEnableVersion());
                    if (StringUtils.isEmpty(version)) {
                        log.warn("模型{}未找到可用版本", algorithmModelEnum.name());
                        continue;
                    }

                    //加载模型
                    loadModelByVersion(modelId, modelInfo.getPath(), version);
                } catch (Exception e) {
                    log.error("加载模型{}失败", algorithmModelEnum.name(), e);
                }
            }
            log.info("加载模型完成, 模型:{}, 模型版本:{}, 模型特征:{}", savedModelBundleMap.keySet(), modelVersionMap, oneHotMap.keySet());
        }
    }

    /**
     * 获取模型的配置
     *
     * @param modelId
     * @return
     */
    private AlgorithmProperties.ModelInfo getModelInfo(String modelId) {
        List<AlgorithmProperties.ModelInfo> modelList = algorithmProperties.getModelList();
        if (CollectionUtils.isEmpty(modelList)) {
            return null;
        }
        return modelList.stream().filter(model -> model.getId().equals(modelId)).findFirst().orElse(null);
    }

    /**
     * 找到最新版本，如果有指定版本则找指定版本
     *
     * @param modelId
     * @param path
     * @param enableVersion
     * @return
     */
    private String findVersion(String modelId, String path, String enableVersion) {
        Path p = Paths.get(path);

        try (DirectoryStream<Path> stream = Files.newDirectoryStream(p)) {
            String finalEnableVersion = null;
            //获取所有版本
            for (Path entry : stream) {
                if (Files.isDirectory(entry)) {
                    String version = entry.getFileName().toString();
                    //筛选加载起效的模型版本
                    if (StringUtils.isNotBlank(enableVersion) && enableVersion.equals(version)) {
                        finalEnableVersion = version;
                        break;
                    }
                    if (StringUtils.isEmpty(finalEnableVersion)) {
                        finalEnableVersion = version;
                    } else {
                        Integer fv = Integer.valueOf(finalEnableVersion);
                        Integer v = Integer.valueOf(version);
                        if (v > fv) {
                            finalEnableVersion = version;
                        } else {
                            continue;
                        }
                    }
                }
            }
            return finalEnableVersion;
        } catch (Exception e) {
            log.error("加载模型版本{}失败", modelId, e);
            return null;
        }
    }

    /**
     * 加载指定版本的模型版本
     *
     * @param modelId
     * @param path
     * @param version
     */
    private void loadModelByVersion(String modelId, String path, String version) {
        //加载特征
        Map<String, Long> directOneHot = loadFieldName(CacheKeys.Algorithm.oneHotKey(modelId, version));
        setOneHotMap(modelId, directOneHot);

        SavedModelBundle newModel = loadByModelId(path + "/" + version);
        if (Objects.nonNull(newModel)) {
            SavedModelBundle old = setSavedModelBundleMap(modelId, newModel);
            String oldVersion = setModelVersionMap(modelId, version);
            //关闭旧的模型
            schedulerClose.scheduleWithFixedDelay(() -> {
                if (old != null) {
                    old.session().close();
                    old.close();
                }
            }, 0, 10, TimeUnit.SECONDS);
        }
    }

    /**
     * 从redisKey获取 特征名称
     *
     * @return
     */
    public Map<String, Long> loadFieldName(KeyGenerator keyGenerator) {
        Map<String, Long> map = new HashMap<>();
        String value = fetcher.fetch(new TargetFactory<String>()
                .keyGenerator(keyGenerator)
                .selector(2)
                .cls(String.class)
                .skipNativeCache(false)
                .build());
        if (StringUtils.isEmpty(value)) {
            log.warn("加载特征名称失败{}", keyGenerator.generateKey());
            return null;
        }
        String[] result = value.split(",");
        for (int i = 0; i < result.length; i++) {
            String[] finalResult = result[i].split(":");
            if (finalResult.length == 2) {
                map.put(finalResult[0].replace("'", "").trim(), Long.valueOf(finalResult[1].trim()));
            }
        }
        return map;
    }

    /**
     * 加载模型
     *
     * @param url
     * @return
     */
    public SavedModelBundle loadByModelId(String url) {
        SavedModelBundle savedModelBundle = SavedModelBundle.load(url, "serve");
        return savedModelBundle;
    }
}

