package com.bxm.component.mybatis.utils;

import com.bxm.newidea.component.tools.SpringContextHolder;
import com.google.common.collect.ImmutableList;
import org.apache.ibatis.session.ExecutorType;
import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.SqlSessionTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;

/**
 * Batch-processing helper for MyBatis.
 * <p>
 * Each {@link #run} call opens a dedicated {@link SqlSession} with {@link ExecutorType#BATCH} from the
 * configured {@link SqlSessionTemplate}'s factory. Every element of the data set is handed to a
 * {@link MapperRunner} together with a mapper bound to that session. The session is committed every
 * {@code limit} elements and once more for a trailing partial batch.
 * <p>
 * NOTE(review): the session is opened directly from the factory rather than through the template, so it
 * presumably does not join an ongoing Spring-managed transaction. Confirm this is intended.
 * <p>
 * Instances are mutable and unsynchronized; configure and run them from a single thread.
 *
 * @param <M> the Mapper interface to execute
 * @param <E> the type of the objects processed in the batch
 */
public class MybatisBatchBuilder<M, E> {

    private static final Logger LOGGER = LoggerFactory.getLogger(MybatisBatchBuilder.class);

    /** Default number of elements committed per batch. */
    private static final int DEFAULT_LIMIT = 500;

    /**
     * Placeholder that MyBatis returns from insert/update/delete mapper methods executed through a BATCH
     * session instead of a real row count. Mirrors
     * {@code org.apache.ibatis.executor.BatchExecutor#BATCH_UPDATE_RETURN_VALUE}.
     */
    private static final int BATCH_UPDATE_RETURN_VALUE = Integer.MIN_VALUE + 1002;

    private SqlSessionTemplate sqlSessionTemplate;

    private final Class<M> mapperClass;

    private final Collection<E> data;

    private boolean strict = false;

    private int limit = DEFAULT_LIMIT;

    private MybatisBatchBuilder(Class<M> mapperClass, Collection<E> data) {
        this.mapperClass = mapperClass;
        this.data = data;
    }

    private MybatisBatchBuilder(Class<M> mapperClass, E[] data) {
        this.mapperClass = mapperClass;
        // Defensive copy of the caller's array. ImmutableList.copyOf already yields a List, and a null
        // array is treated like an empty data set. Guava's ImmutableList rejects null elements.
        this.data = data == null ? ImmutableList.<E>of() : ImmutableList.copyOf(data);
    }

    /**
     * Creates a batch builder.
     *
     * @param mapperClass the Mapper interface class to execute
     * @param data        the elements to process in the batch
     * @return the batch builder
     */
    public static <M, E> MybatisBatchBuilder<M, E> create(Class<M> mapperClass, Collection<E> data) {
        return new MybatisBatchBuilder<>(mapperClass, data);
    }

    /**
     * Creates a batch builder.
     *
     * @param mapperClass the Mapper interface class to execute
     * @param data        the elements to process in the batch (copied; must not contain nulls)
     * @return the batch builder
     */
    public static <M, E> MybatisBatchBuilder<M, E> create(Class<M> mapperClass, E[] data) {
        return new MybatisBatchBuilder<>(mapperClass, data);
    }

    /**
     * Sets the SqlSessionTemplate that owns the Mapper, for multi-datasource setups.
     *
     * @param sqlSessionTemplate the SqlSessionTemplate to open batch sessions from
     * @return this
     */
    public MybatisBatchBuilder<M, E> sessionTemplate(SqlSessionTemplate sqlSessionTemplate) {
        this.sqlSessionTemplate = sqlSessionTemplate;
        return this;
    }

    /**
     * Sets the SqlSessionTemplate that owns the Mapper by its Spring bean name, for multi-datasource setups.
     * The bean is looked up immediately via {@link SpringContextHolder#getBean(String)}.
     *
     * @param sqlSessionTemplateName Spring bean name of the SqlSessionTemplate
     * @return this
     */
    public MybatisBatchBuilder<M, E> sessionTemplateName(String sqlSessionTemplateName) {
        this.sqlSessionTemplate = SpringContextHolder.getBean(sqlSessionTemplateName);
        return this;
    }

    /**
     * Sets whether the result is checked strictly.
     *
     * @param strict {@code true} for strict mode, in which every element must be committed successfully
     *               for {@link #run} to report success
     * @return this
     */
    public MybatisBatchBuilder<M, E> strict(boolean strict) {
        this.strict = strict;
        return this;
    }

    /**
     * Sets how many elements are committed per batch (default 500).
     *
     * @param limit number of elements per commit; must be positive
     * @return this
     * @throws IllegalArgumentException if {@code limit} is zero or negative
     */
    public MybatisBatchBuilder<M, E> limit(int limit) {
        // Fail fast: a zero limit would otherwise surface later as an ArithmeticException inside run().
        if (limit <= 0) {
            throw new IllegalArgumentException("limit must be positive: " + limit);
        }
        this.limit = limit;
        return this;
    }

    /**
     * Runs the batch.
     * <p>
     * Each element is passed to {@code runner} with a mapper bound to a BATCH session. The session is
     * committed every {@code limit} elements and once more for the trailing partial batch. If anything fails
     * while processing or committing, the current (uncommitted) batch is rolled back and the error is logged.
     * Batches committed before the failure stay committed. The session is always closed.
     * <p>
     * Only results of batches that were committed successfully are counted. A runner result equal to the
     * BATCH placeholder (what mapper write methods return in a BATCH session) counts as one processed
     * element. Any other value is taken as the runner's own count.
     *
     * @param runner invokes the mapper for one element
     * @return in strict mode, {@code true} only if the committed count equals the number of elements;
     *         otherwise {@code true} if the committed count is non-zero. {@code null} data yields {@code false}.
     */
    public boolean run(MapperRunner<M, E> runner) {
        if (this.data == null || this.data.isEmpty()) {
            // Nothing to execute: skip opening a session and apply the usual success rule to a zero count.
            return isSuccess(0);
        }
        if (this.sqlSessionTemplate == null) {
            this.sqlSessionTemplate = SpringContextHolder.getBean(SqlSessionTemplate.class);
        }

        int committed = 0;
        // try-with-resources releases the session (and its connection) even if getMapper, commit or
        // rollback throws.
        try (SqlSession session = this.sqlSessionTemplate.getSqlSessionFactory()
                .openSession(ExecutorType.BATCH)) {
            M mapper = session.getMapper(this.mapperClass);
            int pendingElements = 0;
            int pendingCount = 0;
            try {
                for (E element : this.data) {
                    pendingCount += toCount(runner.run(mapper, element));
                    pendingElements++;
                    if (pendingElements == this.limit) {
                        commitBatch(session);
                        // Only count work once its batch has actually been committed.
                        committed += pendingCount;
                        pendingElements = 0;
                        pendingCount = 0;
                    }
                }
                if (pendingElements > 0) {
                    // Commit the trailing partial batch inside the try so its failures take the same
                    // rollback/log path as every other batch.
                    commitBatch(session);
                    committed += pendingCount;
                }
            } catch (Exception e) {
                session.rollback();
                LOGGER.error("Batch execution with mapper [{}] failed: {}",
                        this.mapperClass.getName(), e.getMessage(), e);
            }
        }
        return isSuccess(committed);
    }

    /**
     * Converts one runner result into the count it contributes.
     *
     * @param result the value returned by the runner (accepted as {@code long} so any integral runner
     *               result type is supported)
     * @return 1 for the BATCH placeholder (one queued statement), otherwise the runner's value
     */
    private static int toCount(long result) {
        return result == BATCH_UPDATE_RETURN_VALUE ? 1 : (int) result;
    }

    /** Commits the work queued in {@code session} and clears its local cache. */
    private static void commitBatch(SqlSession session) {
        session.commit();
        session.clearCache();
    }

    /**
     * Applies the configured success rule to a committed count.
     *
     * @param total the number of successfully committed elements/rows
     * @return strict: {@code total} equals the data size; non-strict: {@code total} is non-zero
     */
    private boolean isSuccess(int total) {
        if (this.strict) {
            return this.data != null && total == this.data.size();
        }
        return total != 0;
    }
}
