package com.bianxianmao.offlinemodel.alg;

import com.bianxianmao.offlinemodel.alg.util.FMModelUtil;
import com.bianxianmao.offlinemodel.alg.vo.VectorResult;
import com.bianxianmao.offlinemodel.api.PredResultVo;
import com.bianxianmao.offlinemodel.api.dto.AdvertModelEntity;
import com.bianxianmao.offlinemodel.api.enums.SerializerEnum;
import com.bianxianmao.offlinemodel.mllib.model.SparseFMModel;
import org.apache.spark.mllib.linalg.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.*;

public class FM extends BaseAlgorithm implements Serializable, IAlgorithm {
    private static final Logger logger = LoggerFactory.getLogger(FM.class);
    private FMModelUtil modelUtil = null;

    public FMModelUtil getModelUtil() {
        if (this.modelUtil == null) {
            this.modelUtil = new FMModelUtil();
        }
        return this.modelUtil;
    }

    public FM() {
    }


    public FM(String featureIdxList, String dict, String modelStr, String featureCollectionList, SerializerEnum serializerEnum) {
        setFeatureDict(dict, serializerEnum);
        setModel(modelStr, serializerEnum);
        setFeatureIdxList(featureIdxList, serializerEnum);
        setFeatureCollectionList(featureCollectionList, serializerEnum);
    }

    public FM(AdvertModelEntity entity) {
        SerializerEnum serializerEnum = entity.getSerializerId() == SerializerEnum.KRYO.getIndex() ? SerializerEnum.KRYO : SerializerEnum.JAVA_ORIGINAL;
        setFeatureDict(entity.getFeatureDictStr(), serializerEnum);
        setFeatureIdxList(entity.getFeatureIdxListStr(), serializerEnum);
        setFeatureCollectionList(entity.getFeatureCollectListStr(), serializerEnum);
        setModel(entity.getModelStr(), serializerEnum);


    }

    public void setEntity(AdvertModelEntity entity) {
        SerializerEnum serializerEnum = entity.getSerializerId() == SerializerEnum.KRYO.getIndex() ? SerializerEnum.KRYO : SerializerEnum.JAVA_ORIGINAL;
        setFeatureDict(entity.getFeatureDictStr(), serializerEnum);
        setFeatureIdxList(entity.getFeatureIdxListStr(), serializerEnum);
        setFeatureCollectionList(entity.getFeatureCollectListStr(), serializerEnum);
        setModel(entity.getModelStr(), serializerEnum);
    }


    public void setModel(SparseFMModel model) {
        getModelUtil().setModel(model);
    }

    public void setModel(String modelStr, SerializerEnum serializerEnum) {
        getModelUtil().setModel(modelStr, serializerEnum);
    }

    public String getModelStr(SerializerEnum serializerEnum) {
        return getModelUtil().getModelStr(serializerEnum);
    }


    public Double predict(List<String> categoryList) {
        Double ret = null;

        try {

            VectorResult vr = getDictUtil().oneHotSparseVectorEncode(getFeatureIdxList(), categoryList, getFeatureCollectionList());
            if (vr != null && vr.getVector() != null) {
                ret = getModelUtil().predict(vr.getVector());
            }
        } catch (Exception e) {
            logger.error("predict happend error", e);
        }
        return ret;
    }

    public Double predict(Map<String, String> categoryMap) {
        Double ret = null;
        try {

//            System.out.println("getFeatureCollectionList()"+getFeatureCollectionList());
            VectorResult vr = getDictUtil().oneHotSparseVectorEncodeWithMap(getFeatureIdxList(), categoryMap, getFeatureCollectionList());
            if (vr != null && vr.getVector() != null) {
//                System.out.println("vr.getVector()"+vr.getVector());
                ret = predictWithVector(vr.getVector());
            }
        } catch (Exception e) {
            logger.error("predict happend error", e);
        }

        return ret;
    }



    //多个素材以数组形式给出的预估值
    public List<String> predictWithAssetArray(Map<String, String> categoryMap,String[] assetArray) {

        List<String> assetCtrPair = new ArrayList<String>(); // [素材id， 素材id对应的ctr]
        assetCtrPair.add("null");
        assetCtrPair.add("null");

        try {
            if ( assetArray == null || assetArray.length==0) {
                    Double ctr = predict(categoryMap);
                    assetCtrPair.set(0,"null");
                    assetCtrPair.set(1,ctr.toString());
                }
            else {
                Set<String> assetSet = new HashSet<>(Arrays.asList(assetArray));
                //System.out.println("assetSet:"+assetSet);
                Double maxCtr= -1.0;
                for (String assetId : assetSet) {
                    //System.out.println("assetId:"+assetId);
                    categoryMap.put("f1008",assetId);
                    Double retCtr = predict(categoryMap);
                    //System.out.println("retCtr:" + retCtr);
                    if (retCtr >= maxCtr) {
                        maxCtr = retCtr;
                        assetCtrPair.set(0, assetId);
                        assetCtrPair.set(1, retCtr.toString());
                        //System.out.println("maxCtr:" + maxCtr);
                    }
                }
            }


        } catch (Exception e) {
            logger.error("predict happend error", e);
        }

        return assetCtrPair;
    }




    public PredResultVo predictWithInfo(Map<String, String> categoryMap) {
        PredResultVo ret = new PredResultVo();
        try {

            VectorResult vr = getDictUtil().oneHotSparseVectorEncodeWithMap(getFeatureIdxList(), categoryMap, getFeatureCollectionList());
            if (vr != null && vr.getVector() != null) {
                Double predValue = predictWithVector(vr.getVector());
                ret.setPredValue(predValue);
                ret.setNewFeatureNums(vr.getNewFeatureNums());
                ret.setTotalFeatureNums(vr.getTotalFeatureNums());


            }
        } catch (Exception e) {
            logger.error("predict happend error", e);
        }

        return ret;
    }


    public VectorResult predictWithVectorResult(Map<String, String> categoryMap) {
        VectorResult ret = new VectorResult();
        try {

            VectorResult vr = getDictUtil().oneHotSparseVectorEncodeWithMap(getFeatureIdxList(), categoryMap, getFeatureCollectionList());
            if (vr != null && vr.getVector() != null) {
                Double predValue = predictWithVector(vr.getVector());
                ret.setNewFeatureNums(vr.getNewFeatureNums());
                ret.setTotalFeatureNums(vr.getTotalFeatureNums());



            }
        } catch (Exception e) {
            logger.error("predict happend error", e);
        }

        return ret;
    }



    public Double predictWithVector(SparseVector vector) {
        Double ret = null;
        try {
            if (vector != null) {
//                System.out.println("vector " + vector);
                ret = getModelUtil().predict(vector);
            }
        } catch (Exception e) {
            logger.error("predict happend error", e);
        }

        return ret;
    }



}
