package com.bxm.component.mybatis.utils;

import com.bxm.newidea.component.tools.SpringContextHolder;
import com.google.common.collect.ImmutableList;
import org.apache.ibatis.session.ExecutorType;
import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.SqlSessionTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

import java.util.Collection;

/**
 * mybatis批量操作帮助类
 * @author liujia 2018/8/10 14:54
 */
public abstract class BatchHelper<T, E> {

    /**
     * Mapper接口
     */
    protected T mapper;

    private Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * 批量操作每一次提交的数量
     */
    private short limit = 500;

    private SqlSessionTemplate sqlSessionTemplate;

    private String sqlSessionTemplateName;

    private Collection<E> data;

    private Class<T> mapperClass;

    private int total;

    /**
     * 是否严格校验操作结果
     */
    private boolean strict;

    public BatchHelper(Class<T> mapperClass, Collection<E> data) {
        this.mapperClass = mapperClass;
        this.data = data;
        if (!CollectionUtils.isEmpty(data)) {
            this.doBatch();
        }
    }

    public BatchHelper(Class<T> mapperClass, E[] data) {
        this.mapperClass = mapperClass;
        this.data = ImmutableList.copyOf(data).asList();
        if (!CollectionUtils.isEmpty(this.data)) {
            this.doBatch();
        }
    }

    protected void setLimit(short limit) {
        this.limit = limit;
    }

    protected void setStrict(boolean strict) {
        this.strict = strict;
    }

    /**
     * 批量操作是否成功
     * @return true表示成功
     */
    public final boolean success() {
        if (this.strict) {
            return this.data != null && this.total == this.data.size();
        } else {
            return this.total != 0;
        }
    }

    public void setSqlSessionTemplateName(String sqlSessionTemplateName) {
        this.sqlSessionTemplateName = sqlSessionTemplateName;
    }

    public void setSqlSessionTemplate(SqlSessionTemplate sqlSessionTemplate) {
        this.sqlSessionTemplate = sqlSessionTemplate;
    }

    private void initTemplate() {
        if (this.sqlSessionTemplate == null) {
            if (null != sqlSessionTemplateName) {
                this.sqlSessionTemplate = SpringContextHolder.getBean(sqlSessionTemplateName);
            } else {
                this.sqlSessionTemplate = SpringContextHolder.getBean(SqlSessionTemplate.class);
            }
        }
    }

    private void doBatch() {
        initTemplate();
        SqlSession session = this.sqlSessionTemplate.getSqlSessionFactory().openSession(ExecutorType.BATCH);
        this.mapper = session.getMapper(this.mapperClass);

        try {
            int i = 1;
            for (E element : this.data) {
                this.total += this.invoke(element);
                if (i % this.limit == 0) {
                    session.commit();
                    session.clearCache();
                }
                i++;
            }
        } catch (Exception e) {
            session.rollback();
            this.logger.error(e.getMessage(), e);
        } finally {
            if (this.data.size() % this.limit != 0) {
                session.commit();
                session.clearCache();
            }
            session.close();
        }
    }

    /**
     * 执行数据库具体操作
     * @param element 操作的元素
     * @return 影响行数
     */
    protected abstract int invoke(E element);

}
