package org.springframework.ai.zhipuai;

import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuApiConstants;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/* loaded from: input_file:org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.class */
/**
 * Embedding model that delegates to the ZhiPu AI embeddings endpoint through
 * {@link ZhiPuAiApi}.
 *
 * <p>The endpoint is invoked once per input text: {@link #call(EmbeddingRequest)}
 * issues one API request (wrapped in the configured {@link RetryTemplate}) for
 * every instruction of the request and sums the token usage reported by the
 * individual responses. The whole operation is recorded as a single observation
 * on the configured {@link ObservationRegistry}.
 */
public class ZhiPuAiEmbeddingModel extends AbstractEmbeddingModel {
    private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiEmbeddingModel.class);
    private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

    /** Model name reported in the response metadata when the effective options carry none. */
    private static final String UNKNOWN_MODEL = "unknown";

    /** Options merged underneath the per-request options on every call. */
    private final ZhiPuAiEmbeddingOptions defaultOptions;
    /** Retry policy applied to each individual API request. */
    private final RetryTemplate retryTemplate;
    private final ZhiPuAiApi zhiPuAiApi;
    /** Controls which document metadata is included when a {@link Document} is embedded. */
    private final MetadataMode metadataMode;
    private final ObservationRegistry observationRegistry;
    /** Custom observation convention; starts out as the default convention. */
    private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a model using {@link MetadataMode#EMBED} and the defaults of
     * {@link #ZhiPuAiEmbeddingModel(ZhiPuAiApi, MetadataMode)}.
     *
     * @param zhiPuAiApi the API client, must not be {@code null}
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi) {
        this(zhiPuAiApi, MetadataMode.EMBED);
    }

    /**
     * Creates a model whose default options select
     * {@link ZhiPuAiApi#DEFAULT_EMBEDDING_MODEL} and that retries with
     * {@link RetryUtils#DEFAULT_RETRY_TEMPLATE}.
     *
     * @param zhiPuAiApi the API client, must not be {@code null}
     * @param metadataMode the metadata mode used when embedding documents, must not be {@code null}
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode) {
        this(zhiPuAiApi, metadataMode, ZhiPuAiEmbeddingOptions.builder().model(ZhiPuAiApi.DEFAULT_EMBEDDING_MODEL).build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a model with explicit default options that retries with
     * {@link RetryUtils#DEFAULT_RETRY_TEMPLATE}.
     *
     * @param zhiPuAiApi the API client, must not be {@code null}
     * @param metadataMode the metadata mode used when embedding documents, must not be {@code null}
     * @param zhiPuAiEmbeddingOptions the default options, must not be {@code null}
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode, ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions) {
        this(zhiPuAiApi, metadataMode, zhiPuAiEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a model with explicit default options and retry policy, observed
     * through {@link ObservationRegistry#NOOP}.
     *
     * @param zhiPuAiApi the API client, must not be {@code null}
     * @param metadataMode the metadata mode used when embedding documents, must not be {@code null}
     * @param zhiPuAiEmbeddingOptions the default options, must not be {@code null}
     * @param retryTemplate the retry policy for API requests, must not be {@code null}
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode, ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions, RetryTemplate retryTemplate) {
        this(zhiPuAiApi, metadataMode, zhiPuAiEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP);
    }

    /**
     * Creates a fully configured model.
     *
     * @param zhiPuAiApi the API client, must not be {@code null}
     * @param metadataMode the metadata mode used when embedding documents, must not be {@code null}
     * @param zhiPuAiEmbeddingOptions the default options, must not be {@code null}
     * @param retryTemplate the retry policy for API requests, must not be {@code null}
     * @param observationRegistry the registry that receives the embedding observation, must not be {@code null}
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public ZhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode, ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        Assert.notNull(zhiPuAiApi, "ZhiPuAiApi must not be null");
        Assert.notNull(metadataMode, "metadataMode must not be null");
        Assert.notNull(zhiPuAiEmbeddingOptions, "options must not be null");
        Assert.notNull(retryTemplate, "retryTemplate must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        this.zhiPuAiApi = zhiPuAiApi;
        this.metadataMode = metadataMode;
        this.defaultOptions = zhiPuAiEmbeddingOptions;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }

    /**
     * Embeds the content of a document, formatted according to the configured
     * {@link MetadataMode}.
     *
     * @param document the document to embed, must not be {@code null}
     * @return the embedding vector
     * @throws IllegalArgumentException if {@code document} is {@code null}
     */
    @Override
    public float[] embed(Document document) {
        Assert.notNull(document, "Document must not be null");
        return embed(document.getFormattedContent(this.metadataMode));
    }

    /**
     * Embeds every instruction of the request, one API request per instruction.
     *
     * <p>An instruction for which the API returns no data yields an empty vector
     * (and a warning in the log) so that result indices keep matching the input
     * order.
     *
     * @param embeddingRequest the request; must not be {@code null} and must contain at least one instruction
     * @return the embeddings in input order, with the model name and the summed token usage as metadata
     * @throws IllegalArgumentException if the request is {@code null}, has no
     * instructions, or no model is configured in either the request or the default options
     */
    @Override
    public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
        Assert.notNull(embeddingRequest, "EmbeddingRequest must not be null");
        Assert.notEmpty(embeddingRequest.getInstructions(), "At least one text is required!");
        if (embeddingRequest.getInstructions().size() != 1) {
            logger.warn("ZhiPu Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
        }

        EmbeddingRequest mergedRequest = buildEmbeddingRequest(embeddingRequest);
        EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder()
            .embeddingRequest(mergedRequest)
            .provider(ZhiPuApiConstants.PROVIDER_NAME)
            .build();

        return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
            .observe(() -> {
                EmbeddingResponse response = embedInstructions(mergedRequest);
                observationContext.setResponse(response);
                return response;
            });
    }

    /**
     * Calls the API once per instruction of the already merged request and
     * assembles the combined response.
     */
    private EmbeddingResponse embedInstructions(EmbeddingRequest mergedRequest) {
        EmbeddingOptions options = mergedRequest.getOptions();
        List<String> instructions = mergedRequest.getInstructions();
        List<Embedding> embeddings = new ArrayList<>(instructions.size());
        int completionTokens = 0;
        int promptTokens = 0;
        int totalTokens = 0;

        for (String text : instructions) {
            ZhiPuAiApi.EmbeddingRequest<String> apiRequest = createEmbeddingRequest(text, options);
            var apiResponse = this.retryTemplate.execute(ctx -> this.zhiPuAiApi.embeddings(apiRequest).getBody());

            if (apiResponse == null || apiResponse.data() == null || apiResponse.data().isEmpty()) {
                logger.warn("No embeddings returned for input: {}", text);
                // Placeholder keeps the result index aligned with the input position.
                embeddings.add(new Embedding(new float[0], embeddings.size()));
                continue;
            }

            // The usage block (or any of its counters) may be absent from a response;
            // treat missing values as zero instead of failing the whole call.
            ZhiPuAiApi.Usage usage = apiResponse.usage();
            if (usage != null) {
                completionTokens += orZero(usage.completionTokens());
                promptTokens += orZero(usage.promptTokens());
                totalTokens += orZero(usage.totalTokens());
            }
            embeddings.add(new Embedding(apiResponse.data().get(0).embedding(), embeddings.size()));
        }

        // Report the model that was actually sent to the API (request options merged
        // with the defaults), not just what the caller happened to specify.
        String model = (options != null && StringUtils.hasText(options.getModel())) ? options.getModel() : UNKNOWN_MODEL;
        ZhiPuAiApi.Usage summedUsage = new ZhiPuAiApi.Usage(completionTokens, promptTokens, totalTokens);
        EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(model, getDefaultUsage(summedUsage));
        return new EmbeddingResponse(List.copyOf(embeddings), metadata);
    }

    /** Unboxes a possibly missing token counter, treating {@code null} as zero. */
    private static int orZero(Integer value) {
        return value != null ? value : 0;
    }

    /** Converts the API usage record into the framework usage type, keeping the native record attached. */
    private DefaultUsage getDefaultUsage(ZhiPuAiApi.Usage usage) {
        return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
    }

    /**
     * Builds the effective request by merging the request options with the
     * default options of this model.
     *
     * @param embeddingRequest the incoming request
     * @return a request with the same instructions and the merged options
     * @throws IllegalArgumentException if the merged options do not name a model
     */
    EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
        ZhiPuAiEmbeddingOptions requestOptions = null;
        if (embeddingRequest.getOptions() != null) {
            requestOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, ZhiPuAiEmbeddingOptions.class);
        }
        ZhiPuAiEmbeddingOptions mergedOptions = ModelOptionsUtils.merge(requestOptions, this.defaultOptions, ZhiPuAiEmbeddingOptions.class);
        if (!StringUtils.hasText(mergedOptions.getModel())) {
            throw new IllegalArgumentException("model cannot be null or empty");
        }
        return new EmbeddingRequest(embeddingRequest.getInstructions(), mergedOptions);
    }

    /** Creates the API-level request for a single input text. */
    private ZhiPuAiApi.EmbeddingRequest<String> createEmbeddingRequest(String str, EmbeddingOptions embeddingOptions) {
        return new ZhiPuAiApi.EmbeddingRequest<>(str, embeddingOptions.getModel(), embeddingOptions.getDimensions());
    }

    /**
     * Sets the custom convention passed to the embedding observation.
     *
     * @param embeddingModelObservationConvention the convention to use; no null
     * check is performed here (NOTE(review): a {@code null} value presumably makes
     * the observation fall back to the default convention - confirm against the
     * Micrometer contract)
     */
    public void setObservationConvention(EmbeddingModelObservationConvention embeddingModelObservationConvention) {
        this.observationConvention = embeddingModelObservationConvention;
    }
}
