package com.aliyun.dashvector.utils;

import com.aliyun.dashvector.common.DashVectorException;
import com.aliyun.dashvector.common.ErrorCode;
import com.aliyun.dashvector.models.*;
import com.aliyun.dashvector.models.Doc;
import com.aliyun.dashvector.models.RequestUsage;
import com.aliyun.dashvector.models.VectorQuery;
import com.aliyun.dashvector.proto.*;
import com.aliyun.dashvector.proto.Vector;
import com.google.common.primitives.Floats;
import com.google.protobuf.ByteString;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;

/**
 * Static conversion helpers between the DashVector SDK model classes
 * ({@code com.aliyun.dashvector.models}) and the generated protobuf messages
 * ({@code com.aliyun.dashvector.proto}).
 *
 * @author sanyi
 */
public class Convertor {
  /**
   * Convert doc from model to proto.
   *
   * <p>When {@code doc.getVectors()} is non-empty the named dense vectors are written and
   * {@code doc.getVector()} is ignored; likewise a non-empty {@code doc.getSparseVectors()} takes
   * precedence over {@code doc.getSparseVector()}. The id and fields are only set when non-empty.
   *
   * @param doc            {@link Doc}
   * @param collectionMeta {@link CollectionMeta} supplying the data type of each dense vector; it is
   *                       only dereferenced when the doc carries a dense vector
   * @return {@link com.aliyun.dashvector.proto.Doc}
   */
  public static com.aliyun.dashvector.proto.Doc toDoc(
          @NonNull Doc doc, CollectionMeta collectionMeta) {
    com.aliyun.dashvector.proto.Doc.Builder builder =
        com.aliyun.dashvector.proto.Doc.newBuilder().setScore(doc.getScore());
    // convert dense vector(s): named vectors win over the single default vector
    if (MapUtils.isNotEmpty(doc.getVectors())) {
      doc.getVectors().forEach(
          (k, v) -> builder.putVectors(k, toVector(v.getValue(), collectionMeta.getDataType(k))));
    } else if (Objects.nonNull(doc.getVector())) {
      builder.setVector(toVector(doc.getVector().getValue(), collectionMeta.getDataType()));
    }

    // convert sparse vector(s): named sparse vectors win over the single default sparse vector
    if (MapUtils.isNotEmpty(doc.getSparseVectors())) {
      doc.getSparseVectors().forEach(
          (k, v) -> builder.putSparseVectors(k, toDashSparseVector(v)));
    } else if (MapUtils.isNotEmpty(doc.getSparseVector())) {
      builder.putAllSparseVector(toDashVectorSparse(doc.getSparseVector()));
    }

    // convert id
    if (StringUtils.isNotEmpty(doc.getId())) {
      builder.setId(doc.getId());
    }

    // convert doc fields
    if (MapUtils.isNotEmpty(doc.getFields())) {
      builder.putAllFields(toFieldMap(doc.getFields()));
    }
    return builder.build();
  }

  /**
   * Convert doc from proto to model.
   *
   * <p>A non-empty named-vector map takes precedence over the single default vector, and a
   * non-empty named sparse-vector map takes precedence over the single default sparse vector.
   *
   * @param doc {@link com.aliyun.dashvector.proto.Doc}
   * @param collectionMeta {@link CollectionMeta} supplying the data type of each dense vector
   * @return {@link com.aliyun.dashvector.models.Doc}
   */
  public static Doc fromDoc(
          @NonNull com.aliyun.dashvector.proto.Doc doc, @NonNull CollectionMeta collectionMeta) {
    Doc.DocBuilder builder = Doc.builder()
            .id(doc.getId())
            .fields(fromFieldMap(doc.getFieldsMap()))
            .score(doc.getScore());

    // convert dense vector(s)
    if (MapUtils.isNotEmpty(doc.getVectorsMap())) {
      Map<String, com.aliyun.dashvector.models.Vector> vectors = new HashMap<>();
      // NOTE(review): this throws NullPointerException if the response names a vector that
      // collectionMeta.getVectors() does not contain — confirm that cannot happen.
      doc.getVectorsMap().forEach(
          (k, v) -> vectors.put(k, fromVector(v, collectionMeta.getVectors().get(k).getDataType())));
      builder.vectors(vectors);
    } else if (doc.hasVector()) {
      builder.vector(fromVector(doc.getVector(), collectionMeta.getDataType()));
    }

    // convert sparse vector(s)
    if (MapUtils.isNotEmpty(doc.getSparseVectorsMap())) {
      doc.getSparseVectorsMap().forEach(
          (k, v) -> builder.sparseVectors(k, fromDashVectorSparse(v)));
    } else if (MapUtils.isNotEmpty(doc.getSparseVectorMap())) {
      builder.sparseVector(fromDashVectorSparse(doc.getSparseVectorMap()));
    }

    return builder.build();
  }

  /**
   * Convert vector from model to proto.
   *
   * <p>For {@link CollectionInfo.DataType#INT} every element is narrowed with
   * {@link Number#byteValue()} (so values outside [-128, 127] wrap) and packed into the byte
   * vector; every other data type is written as a float vector.
   *
   * @param input vector elements
   * @param dataType proto.DataType {@link com.aliyun.dashvector.proto.CollectionInfo.DataType}
   * @return {@link Vector}
   */
  public static Vector toVector(
      @NonNull List<? extends Number> input, @NonNull CollectionInfo.DataType dataType) {
    // here is int8
    if (dataType == CollectionInfo.DataType.INT) {
      ByteBuffer buffer = ByteBuffer.allocate(input.size()).order(ByteOrder.LITTLE_ENDIAN);
      input.forEach(e -> buffer.put(e.byteValue()));
      return Vector.newBuilder().setByteVector(ByteString.copyFrom(buffer.array())).build();
    }

    return Vector.newBuilder()
        .setFloatVector(
            Vector.FloatVector.newBuilder().addAllValues(Floats.asList(Floats.toArray(input))))
        .build();
  }

  /**
   * Convert vector from proto to model.
   *
   * <p>For {@link CollectionInfo.DataType#INT} each byte of the byte vector is read as a signed
   * int8 and widened to {@link Integer}; every other data type reads the float vector.
   *
   * @param vector {@link Vector}
   * @param dataType {@link com.aliyun.dashvector.proto.CollectionInfo.DataType}
   * @return {@link com.aliyun.dashvector.models.Vector}
   */
  public static com.aliyun.dashvector.models.Vector fromVector(
      @NonNull Vector vector, @NonNull CollectionInfo.DataType dataType) {
    if (dataType == CollectionInfo.DataType.INT) {
      List<Integer> integerList = new ArrayList<>();
      ByteBuffer buffer =
          ByteBuffer.wrap(vector.getByteVector().toByteArray()).order(ByteOrder.LITTLE_ENDIAN);
      while (buffer.hasRemaining()) {
        integerList.add((int) buffer.get());
      }
      return com.aliyun.dashvector.models.Vector.builder().value(integerList).build();
    }
    return com.aliyun.dashvector.models.Vector.builder()
        .value(vector.getFloatVector().getValuesList())
        .build();
  }

  /**
   * Convert FieldMap Value from Object to {@link FieldValue}.
   *
   * <p>The mapping is by runtime type (schema-free): Integer, Float, String, Boolean and Long are
   * supported; null values are skipped.
   *
   * @param fieldMap the map key is String and value is Object
   * @return map that key is String and value is {@link FieldValue}
   * @throws DashVectorException with {@link ErrorCode#INVALID_FIELD} for any other value type
   */
  public static Map<String, FieldValue> toFieldMap(@NonNull Map<String, Object> fieldMap) {
    Map<String, FieldValue> field = new HashMap<>();
    for (Map.Entry<String, Object> entry : fieldMap.entrySet()) {
      String key = entry.getKey();
      Object value = entry.getValue();
      if (value == null) {
        continue;
      }
      // should support schema-free, so we can't primarily rely on fieldSchema here
      if (value instanceof Integer) {
        field.put(key, FieldValue.newBuilder().setIntValue((int) value).build());
      } else if (value instanceof Float) {
        field.put(key, FieldValue.newBuilder().setFloatValue((float) value).build());
      } else if (value instanceof String) {
        field.put(key, FieldValue.newBuilder().setStringValue((String) value).build());
      } else if (value instanceof Boolean) {
        field.put(key, FieldValue.newBuilder().setBoolValue((boolean) value).build());
      } else if (value instanceof Long) {
        field.put(key, FieldValue.newBuilder().setLongValue((long) value).build());
      } else {
        throw new DashVectorException(
            ErrorCode.INVALID_FIELD.getCode(),
            String.format(
                "DashVectorSDK does not support input field value[%s] and must be in [bool, str, int, float, long]",
                value));
      }
    }
    return field;
  }

  /**
   * Convert FieldMap Value from {@link FieldValue} to Object.
   *
   * <p>Entries whose oneof is not set are dropped.
   *
   * @param fieldMap the key is String and value is {@link FieldValue}
   * @return map that key is String and value is Object (Integer, Boolean, Float, String or Long)
   * @throws DashVectorException with {@link ErrorCode#INVALID_FIELD} for an unrecognised oneof case
   */
  public static Map<String, Object> fromFieldMap(@NonNull Map<String, FieldValue> fieldMap) {
    Map<String, Object> field = new HashMap<>();
    for (Map.Entry<String, FieldValue> entry : fieldMap.entrySet()) {
      String key = entry.getKey();
      FieldValue value = entry.getValue();
      switch (value.getValueOneofCase()) {
        case INT_VALUE:
          field.put(key, value.getIntValue());
          break;
        case BOOL_VALUE:
          field.put(key, value.getBoolValue());
          break;
        case FLOAT_VALUE:
          field.put(key, value.getFloatValue());
          break;
        case STRING_VALUE:
          field.put(key, value.getStringValue());
          break;
        case LONG_VALUE:
          field.put(key, value.getLongValue());
          break;
        case VALUEONEOF_NOT_SET:
          break;
        default:
          // Message lists every supported type, including long (previously omitted).
          throw new DashVectorException(
              ErrorCode.INVALID_FIELD.getCode(),
              String.format(
                  "DashVectorSDK does not support receive field value[%s] and must be in [bool, str, int, float, long]",
                  value));
      }
    }
    return field;
  }

  /**
   * Copy a fields schema, re-resolving every {@link FieldType} by name.
   *
   * @param fieldsSchema the key is String and value is {@link FieldType}
   * @return a new schema map with the same keys and {@link FieldType} values
   */
  public static Map<String, FieldType> toFieldsSchema(
      @NonNull Map<String, FieldType> fieldsSchema) {
    Map<String, FieldType> schema = new HashMap<>();
    for (Map.Entry<String, FieldType> entry : fieldsSchema.entrySet()) {
      schema.put(entry.getKey(), FieldType.valueOf(entry.getValue().name()));
    }
    return schema;
  }

  /**
   * Convert a model sparse vector ({@code index -> weight}) to the proto key type.
   *
   * <p>Entries are emitted in ascending index order. The result is a {@link LinkedHashMap}
   * because the previous {@code Collectors.toMap} collected into a {@code HashMap} and discarded
   * the sort. Indices are narrowed with {@link Long#intValue()}, so indices in [2^31, 2^32) are
   * stored as their uint32 bit pattern (a negative int); {@link #fromDashVectorSparse(Map)} reads
   * them back as unsigned. NOTE(review): indices outside [0, 2^32) are silently truncated.
   *
   * @param sparseVector index to weight; neither keys nor values may be null
   * @return index (as int) to weight, in ascending index order
   * @throws IllegalStateException if two indices collide after narrowing to int
   */
  public static Map<Integer, Float> toDashVectorSparse(Map<Long, Float> sparseVector) {
    Map<Integer, Float> indexed = new LinkedHashMap<>();
    // TreeMap iterates in ascending index order; LinkedHashMap preserves that order.
    for (Map.Entry<Long, Float> entry : new TreeMap<>(sparseVector).entrySet()) {
      int index = entry.getKey().intValue();
      Float value = Objects.requireNonNull(entry.getValue());
      if (indexed.putIfAbsent(index, value) != null) {
        throw new IllegalStateException("Duplicate key " + index);
      }
    }
    return indexed;
  }

  /**
   * Convert a proto sparse vector to the model key type, reading each int index as unsigned.
   *
   * @param sparseVector index (uint32 bit pattern) to weight
   * @return index (as non-negative long) to weight
   */
  public static Map<Long, Float> fromDashVectorSparse(Map<Integer, Float> sparseVector) {
    return sparseVector.entrySet().stream()
        .collect(
            Collectors.toMap(entry -> Integer.toUnsignedLong(entry.getKey()), Map.Entry::getValue));
  }

  /**
   * Convert a proto group result to a model {@link Group}, converting every contained doc.
   *
   * @param group proto group result
   * @param collectionMeta metadata used to convert each doc's vectors
   * @return model group
   */
  public static Group fromGroup(
          GroupResult group, CollectionMeta collectionMeta) {
    return Group.builder()
        .groupId(group.getGroupId())
        .docs(
            group.getDocsList().stream()
                .map(doc -> fromDoc(doc, collectionMeta))
                .collect(Collectors.toList()))
        .build();
  }

  /**
   * Convert proto usage to model usage. Read units win over write units when both are present.
   *
   * @param usage proto usage
   * @return model usage carrying either read units or write units
   * @throws DashVectorException with {@link ErrorCode#UNKNOWN} if neither field is set
   */
  public static RequestUsage toRequestUsage(@NonNull com.aliyun.dashvector.proto.RequestUsage usage) {
    RequestUsage.RequestUsageBuilder builder = com.aliyun.dashvector.models.RequestUsage.builder();
    if (usage.hasReadUnits()) {
      return builder.readUnits(usage.getReadUnits()).build();
    }
    if (usage.hasWriteUnits()) {
      return builder.writeUnits(usage.getWriteUnits()).build();
    }
    throw new DashVectorException(
            ErrorCode.UNKNOWN.getCode(),
            "DashVectorSDK get wrong upper stream usage response, empty read_units and write_units field");
  }

  /**
   * Convert a dense vector query from model to proto.
   *
   * @param query model query; its vector must be non-null
   * @param dataType data type used to encode the query vector
   * @return proto vector query
   */
  public static com.aliyun.dashvector.proto.VectorQuery toVectorQuery(VectorQuery query, CollectionInfo.DataType dataType) {
    return com.aliyun.dashvector.proto.VectorQuery.newBuilder()
            .setVector(toVector(query.getVector().getValue(), dataType))
            .setParam(toVectorQueryParam(query))
            .build();
  }

  /**
   * Copy the search parameters (num candidates, radius, linear flag, ef) of a dense query.
   *
   * @param param model query
   * @return proto query parameters
   */
  public static com.aliyun.dashvector.proto.VectorQueryParam toVectorQueryParam(VectorQuery param) {
    return com.aliyun.dashvector.proto.VectorQueryParam.newBuilder()
        .setNumCandidates(param.getNumCandidates())
        .setRadius(param.getRadius())
        .setIsLinear(param.isLinear())
        .setEf(param.getEf())
        .build();
  }

  /**
   * Convert a vector parameter definition from model to proto.
   *
   * @param param model vector parameter
   * @return proto vector parameter with dimension, dtype, metric and quantize type copied
   */
  public static com.aliyun.dashvector.proto.CollectionInfo.VectorParam toVectorParam(VectorParam param) {
    return com.aliyun.dashvector.proto.CollectionInfo.VectorParam.newBuilder()
            .setDimension(param.getDimension())
            .setDtype(param.getDataType())
            .setMetric(param.getMetric())
            .setQuantizeType(param.getQuantizeType())
            .build();
  }

  /**
   * Convert a vector parameter definition from proto to model.
   *
   * @param param proto vector parameter
   * @return model vector parameter with dimension, data type, metric and quantize type copied
   */
  public static VectorParam toVectorParam(CollectionInfo.VectorParam param) {
    return VectorParam.builder()
            .dimension(param.getDimension())
            .dataType(param.getDtype())
            .metric(param.getMetric())
            .quantizeType(param.getQuantizeType())
            .build();
  }

  /**
   * Convert a sparse vector query from model to proto.
   *
   * @param query model sparse query
   * @param dataType currently unused
   * @return proto sparse vector query
   */
  public static com.aliyun.dashvector.proto.SparseVectorQuery toSparseVectorQuery(
          com.aliyun.dashvector.models.SparseVectorQuery query, CollectionInfo.DataType dataType) {
    return com.aliyun.dashvector.proto.SparseVectorQuery.newBuilder()
            .setSparseVector(toDashSparseVector(query.getVector()))
            .setParam(toSparseVectorQueryParam(query))
            .build();
  }

  /**
   * Convert a model sparse vector message to proto. Index ordering and narrowing are identical to
   * {@link #toDashVectorSparse(Map)}, which this delegates to.
   */
  private static com.aliyun.dashvector.proto.SparseVector toDashSparseVector(
      com.aliyun.dashvector.models.SparseVector sparseVector) {
    return com.aliyun.dashvector.proto.SparseVector.newBuilder()
        .putAllSparseVector(toDashVectorSparse(sparseVector.getValue()))
        .build();
  }

  /**
   * Convert a proto sparse vector message to model. This delegates to
   * {@link #fromDashVectorSparse(Map)}, so indices are read as unsigned. The previous signed
   * {@code longValue()} turned indices >= 2^31 negative, unlike the single-sparse-vector path.
   */
  private static com.aliyun.dashvector.models.SparseVector fromDashVectorSparse(
      com.aliyun.dashvector.proto.SparseVector sparseVector) {
    return com.aliyun.dashvector.models.SparseVector.builder()
        .value(fromDashVectorSparse(sparseVector.getSparseVectorMap()))
        .build();
  }

  /** Copy the search parameters (num candidates, radius, linear flag, ef) of a sparse query. */
  private static com.aliyun.dashvector.proto.VectorQueryParam toSparseVectorQueryParam(
          com.aliyun.dashvector.models.SparseVectorQuery param) {
    return com.aliyun.dashvector.proto.VectorQueryParam.newBuilder()
            .setNumCandidates(param.getNumCandidates())
            .setRadius(param.getRadius())
            .setIsLinear(param.isLinear())
            .setEf(param.getEf())
            .build();
  }

}
