package com.aliyun.dashvector;

import com.aliyun.dashvector.common.Constants;
import com.aliyun.dashvector.common.DashVectorException;
import com.aliyun.dashvector.common.ErrorCode;
import com.aliyun.dashvector.models.CollectionMeta;
import com.aliyun.dashvector.models.requests.CreateCollectionRequest;
import com.aliyun.dashvector.models.responses.Response;
import com.aliyun.dashvector.proto.*;
import com.aliyun.dashvector.utils.Utils;
import com.aliyun.dashvector.utils.Validator;
import io.grpc.*;
import io.grpc.stub.MetadataUtils;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import lombok.Getter;
import lombok.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client for the DashVector service, talking gRPC over a single {@link ManagedChannel} owned by
 * this instance. Call {@link #close()} to release the channel; collections returned by {@link
 * #get(String)} share this client's stub and therefore its channel.
 *
 * <p>NOTE(review): the collection cache is a {@link ConcurrentHashMap}, but no broader
 * thread-safety guarantee is documented for this class — confirm before sharing across threads.
 *
 * @author hq90172
 */
public class DashVectorClient {
  private static final Logger logger = LoggerFactory.getLogger(DashVectorClient.class);

  /** Seconds to wait between two status polls while waiting for a created collection. */
  private static final long CREATE_POLL_INTERVAL_SECONDS = 5L;

  /** Number of failed describe calls tolerated while polling a newly created collection. */
  private static final int MAX_DESCRIBE_FAILURES = 3;

  @Getter @NonNull private final DashVectorClientConfig config;
  private final DashVectorServiceGrpc.DashVectorServiceFutureStub stub;
  private final ManagedChannel channel;
  private final ConcurrentHashMap<String, DashVectorCollection> cache;

  /**
   * Applies a fixed deadline, {@code timeoutSeconds} from the start of the call, to every outgoing
   * call, replacing any deadline already present in the call options.
   */
  private static class TimeoutInterceptor implements ClientInterceptor {
    private final long timeoutSeconds;

    TimeoutInterceptor(long timeoutSeconds) {
      this.timeoutSeconds = timeoutSeconds;
    }

    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
        MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
      return next.newCall(method, callOptions.withDeadlineAfter(timeoutSeconds, TimeUnit.SECONDS));
    }
  }

  /**
   * Creates a client from an API key and an endpoint, using defaults for every other setting of
   * {@link DashVectorClientConfig}.
   *
   * @param apiKey API key sent with every request
   * @param endpoint service endpoint; must pass {@code Validator.verifyEndpoint}
   * @throws DashVectorException if the endpoint is invalid or the service version check fails
   */
  public DashVectorClient(@NonNull String apiKey, @NonNull String endpoint)
      throws DashVectorException {
    this(DashVectorClientConfig.builder().apiKey(apiKey).endpoint(endpoint).build());
  }

  /**
   * Creates a client from a full configuration, opens the gRPC channel and verifies connectivity by
   * calling the service's GetVersion RPC.
   *
   * @param config client configuration (endpoint, API key, per-call timeout in seconds)
   * @throws DashVectorException with {@code INVALID_ENDPOINT} if the endpoint is rejected, or with
   *     the service/gRPC error code if the version check fails; the channel is shut down in the
   *     latter case so it does not leak
   */
  public DashVectorClient(@NonNull DashVectorClientConfig config) throws DashVectorException {
    this.config = config;

    String endpoint = config.getEndpoint();
    if (!Validator.verifyEndpoint(endpoint)) {
      throw new DashVectorException(ErrorCode.INVALID_ENDPOINT);
    }

    // Every call carries the API key and SDK user agent as headers, plus a fixed deadline.
    Metadata metadata = new Metadata();
    metadata.put(
        Metadata.Key.of(Constants.HEADER_TOKEN, Metadata.ASCII_STRING_MARSHALLER),
        config.getApiKey());
    metadata.put(
        Metadata.Key.of(Constants.HEADER_USER_AGENT, Metadata.ASCII_STRING_MARSHALLER),
        Utils.getUserAgent());
    TimeoutInterceptor timeoutInterceptor = new TimeoutInterceptor(config.getTimeout().longValue());

    ManagedChannelBuilder<?> builder =
        ManagedChannelBuilder.forTarget(endpoint)
            .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata), timeoutInterceptor)
            .maxInboundMessageSize(Constants.GRPC_MAX_MSG_SIZE)
            .keepAliveTime(1, TimeUnit.MINUTES);
    if (Utils.isInSecureMode()) {
      builder.usePlaintext();
    } else {
      builder.useTransportSecurity();
    }

    // Build exactly one channel and bind the stub to it, so close() releases what the stub uses.
    this.channel = builder.build();
    this.stub = DashVectorServiceGrpc.newFutureStub(this.channel);
    this.cache = new ConcurrentHashMap<>();

    boolean verified = false;
    try {
      checkServerVersion();
      verified = true;
    } finally {
      if (!verified) {
        // The caller never receives this instance, so nobody else could close the channel.
        this.channel.shutdownNow();
      }
    }
  }

  /**
   * Calls the GetVersion RPC and logs the version reported by the service.
   *
   * @throws DashVectorException if the service answers with a non-success code, or carrying the
   *     gRPC status code (or {@code UNKNOWN}) if the call itself fails or is interrupted
   */
  private void checkServerVersion() throws DashVectorException {
    try {
      GetVersionResponse versionResponse =
          this.stub.getVersion(GetVersionRequest.getDefaultInstance()).get();

      if (versionResponse.getCode() != ErrorCode.SUCCESS.getCode()) {
        throw new DashVectorException(
            versionResponse.getCode(),
            String.format(
                "DashVectorSDK client init failed. Reason: %s. RequestId: %s.",
                versionResponse.getMessage(), versionResponse.getRequestId()));
      }
      logger.info("DashVector service version is {}", versionResponse.getVersion());
    } catch (DashVectorException e) {
      throw e;
    } catch (ExecutionException e) {
      // Future.get() wraps RPC failures; surface the gRPC status code when there is one.
      Throwable cause = e.getCause();
      if (cause instanceof StatusRuntimeException) {
        StatusRuntimeException statusError = (StatusRuntimeException) cause;
        throw new DashVectorException(statusError, statusError.getStatus().getCode().value());
      }
      throw new DashVectorException(e, ErrorCode.UNKNOWN.getCode());
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new DashVectorException(e, ErrorCode.UNKNOWN.getCode());
    } catch (StatusRuntimeException e) {
      throw new DashVectorException(e, e.getStatus().getCode().value());
    } catch (Exception e) {
      throw new DashVectorException(e, ErrorCode.UNKNOWN.getCode());
    }
  }

  /**
   * Creates a collection with the given name and dimension; every other option is left at the
   * {@link CreateCollectionRequest} builder defaults (no timeout is set here).
   *
   * @param name collection name
   * @param dimension vector dimension
   * @return the result of {@link #create(CreateCollectionRequest)}
   */
  public Response<Void> create(@NonNull String name, int dimension) {
    CreateCollectionRequest request =
        CreateCollectionRequest.builder().name(name).dimension(dimension).build();
    return this.create(request);
  }

  /**
   * Sends a create-collection request and then polls {@link #describe(String)} until the collection
   * is SERVING.
   *
   * <p>If the create call itself fails, that failure is returned immediately. Otherwise the status
   * is polled every {@value #CREATE_POLL_INTERVAL_SECONDS} seconds. A {@code null} request timeout
   * polls without a time limit; a non-null timeout returns {@code ErrorCode.TIMEOUT} once the
   * accumulated wait exceeds it. Polling also stops on status ERROR/DROPPING, or after more than
   * {@value #MAX_DESCRIBE_FAILURES} describe calls have failed in total.
   *
   * @param request the create-collection request
   * @return success carrying the create call's code/message/request id, or a failure response
   */
  public Response<Void> create(@NonNull CreateCollectionRequest request) {
    CreateCollectionResponse createResponse;
    try {
      createResponse = this.stub.createCollection(request.toProto()).get();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return Response.failed(e);
    } catch (Exception e) {
      return Response.failed(e);
    }

    // A rejected create is final, regardless of the timeout; polling would either mask the real
    // error or report a pre-existing collection with the same name as freshly created.
    if (createResponse.getCode() != ErrorCode.SUCCESS.getCode()) {
      return Response.create(
          createResponse.getCode(),
          createResponse.getMessage(),
          createResponse.getRequestId(),
          null);
    }

    return this.waitUntilServing(request.getName(), request.getTimeout(), createResponse);
  }

  /**
   * Polls the collection's status until it is SERVING, terminally failed, describe keeps failing,
   * or the optional timeout (seconds) runs out.
   */
  private Response<Void> waitUntilServing(
      String collectionName, Integer timeout, CreateCollectionResponse createResponse) {
    long remainingSeconds = Objects.isNull(timeout) ? 0L : timeout;
    int describeFailures = 0;

    while (true) {
      Response<CollectionMeta> describeResponse = this.describe(collectionName);
      if (describeResponse.getCode() == ErrorCode.SUCCESS.getCode()) {
        String status = describeResponse.getOutput().getStatus();
        if ("ERROR".equals(status) || "DROPPING".equals(status)) {
          // The describe call succeeded, but creation did not: report a failure code, not SUCCESS.
          // NOTE(review): use a dedicated "unready collection" code if ErrorCode defines one.
          return Response.create(
              ErrorCode.UNKNOWN.getCode(),
              String.format(
                  "DashVectorSDK DashVectorCollection(%s) unready. Status is %s ",
                  collectionName, status),
              describeResponse.getRequestId(),
              null);
        }
        if ("SERVING".equals(status)) {
          return Response.success(
              createResponse.getCode(), createResponse.getMessage(), createResponse.getRequestId());
        }
      } else {
        describeFailures++;
      }

      if (describeFailures > MAX_DESCRIBE_FAILURES) {
        return Response.create(
            describeResponse.getCode(),
            String.format(
                "DashVectorSDK Get DashVectorCollection(%s) Status failed.", collectionName),
            describeResponse.getRequestId(),
            null);
      }

      try {
        TimeUnit.SECONDS.sleep(CREATE_POLL_INTERVAL_SECONDS);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        return Response.failed(e);
      }

      if (Objects.isNull(timeout)) {
        continue;
      }
      remainingSeconds -= CREATE_POLL_INTERVAL_SECONDS;
      if (remainingSeconds < 0) {
        return Response.create(
            ErrorCode.TIMEOUT.getCode(),
            String.format(
                "DashVectorSDK Get DashVectorCollection(%s) Status Timeout.", collectionName),
            null,
            null);
      }
    }
  }

  /**
   * Lists the names of all collections.
   *
   * @return the collection names, or a failure response if the call fails
   */
  public Response<List<String>> list() {
    try {
      ListCollectionsResponse response =
          this.stub.listCollections(ListCollectionsRequest.getDefaultInstance()).get();
      return Response.success(response);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return Response.failed(e);
    } catch (Exception e) {
      return Response.failed(e);
    }
  }

  /**
   * Fetches the metadata of a collection after validating its name.
   *
   * @param name collection name
   * @return the collection metadata, or a failure response if validation or the call fails
   */
  public Response<CollectionMeta> describe(@NonNull String name) {
    try {
      Validator.verifyCollectionName(name);
      DescribeCollectionRequest req = DescribeCollectionRequest.newBuilder().setName(name).build();
      DescribeCollectionResponse response = this.stub.describeCollection(req).get();
      return Response.success(response);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return Response.failed(e);
    } catch (Exception e) {
      return Response.failed(e);
    }
  }

  /**
   * Deletes a collection and, when the service reports success, evicts it from the local cache.
   *
   * @param name collection name
   * @return the service's code/message/request id, or a failure response if the call fails
   */
  public Response<Void> delete(@NonNull String name) {
    try {
      Validator.verifyCollectionName(name);
      DeleteCollectionRequest request = DeleteCollectionRequest.newBuilder().setName(name).build();
      DeleteCollectionResponse response = this.stub.deleteCollection(request).get();
      if (response.getCode() == ErrorCode.SUCCESS.getCode()) {
        this.cache.remove(name);
      }
      return Response.success(response.getCode(), response.getMessage(), response.getRequestId());
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return Response.failed(e);
    } catch (Exception e) {
      return Response.failed(e);
    }
  }

  /**
   * Returns a handle for the named collection, serving it from the local cache when present.
   * Handles are cached only when the underlying describe call succeeded.
   *
   * @param name collection name
   * @return the collection handle; check {@code isSuccess()} before use
   */
  public DashVectorCollection get(@NonNull String name) {
    // Single lookup: containsKey()+get() could return null if delete() evicts in between.
    DashVectorCollection cached = this.cache.get(name);
    if (cached != null) {
      return cached;
    }

    DashVectorCollection collection = new DashVectorCollection(this.describe(name), this.stub);
    if (!collection.isSuccess()) {
      return collection;
    }
    // If another thread cached a handle meanwhile, return that one so callers share a handle.
    DashVectorCollection existing = this.cache.putIfAbsent(name, collection);
    return existing != null ? existing : collection;
  }

  /**
   * Forcefully shuts down the gRPC channel ({@code shutdownNow}); in-flight and new calls on this
   * client, including those made through cached collections, are cancelled.
   */
  public void close() {
    logger.info("DashVectorSDK closed.");
    this.channel.shutdownNow();
  }
}
