package dev.langchain4j.classification;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.RelevanceScore;
import java.lang.Enum;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/* loaded from: input_file:dev/langchain4j/classification/EmbeddingModelTextClassifier.class */
/**
 * A {@link TextClassifier} that assigns enum labels to text by comparing the text's embedding with
 * the embeddings of labelled example texts.
 *
 * <p>All examples are embedded once, at construction time. On each {@link #classify(String)} call the
 * input text is embedded and compared with every example of every label. Each comparison yields a
 * relevance score: {@code RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(...))}.
 *
 * <p>For each label, the mean and the maximum of its example scores are blended into a single score:
 * {@code meanToMaxScoreRatio * mean + (1 - meanToMaxScoreRatio) * max}. Labels whose blended score is
 * below {@code minScore} are discarded. The remaining labels are returned highest score first, at most
 * {@code maxResults} of them.
 */
public class EmbeddingModelTextClassifier<E extends Enum<E>> implements TextClassifier<E> {
    private final EmbeddingModel embeddingModel;
    /** Pre-computed example embeddings per label; labels with no examples are not stored. */
    private final Map<E, List<Embedding>> exampleEmbeddingsByLabel;
    private final int maxResults;
    private final double minScore;
    private final double meanToMaxScoreRatio;

    /**
     * A candidate label paired with its blended score for one classified text.
     * Static so it does not capture the enclosing classifier instance.
     */
    private static final class LabelWithScore<L> {
        private final L label;
        private final double score;

        private LabelWithScore(L label, double score) {
            this.label = label;
            this.score = score;
        }
    }

    /**
     * Creates a classifier with the defaults {@code maxResults = 1}, {@code minScore = 0.0}
     * and {@code meanToMaxScoreRatio = 0.5}.
     *
     * @param embeddingModel  model used to embed both the examples and the texts to classify
     * @param examplesByLabel example texts for each label
     */
    public EmbeddingModelTextClassifier(EmbeddingModel embeddingModel, Map<E, ? extends Collection<String>> examplesByLabel) {
        this(embeddingModel, examplesByLabel, 1, 0.0d, 0.5d);
    }

    /**
     * Creates a classifier and eagerly embeds all example texts, one {@code embedAll} call per label
     * that has at least one example.
     *
     * @param embeddingModel       model used to embed both the examples and the texts to classify; must not be null
     * @param examplesByLabel      example texts for each label; must not be null. A label mapped to an
     *                             empty collection can never be returned by {@link #classify(String)}.
     * @param maxResults           maximum number of labels returned by {@link #classify(String)}; must be &gt; 0
     * @param minScore             minimum blended score a label needs to be returned; must be in [0, 1]
     * @param meanToMaxScoreRatio  weight of the mean example score versus the max example score; must be in [0, 1]
     */
    public EmbeddingModelTextClassifier(EmbeddingModel embeddingModel,
                                        Map<E, ? extends Collection<String>> examplesByLabel,
                                        int maxResults,
                                        double minScore,
                                        double meanToMaxScoreRatio) {
        this.embeddingModel = (EmbeddingModel) ValidationUtils.ensureNotNull(embeddingModel, "embeddingModel");
        ValidationUtils.ensureNotNull(examplesByLabel, "examplesByLabel");
        // Validate the cheap scalar arguments first, so a bad configuration fails fast
        // before any (potentially remote and billed) embedding calls are made.
        this.maxResults = ValidationUtils.ensureGreaterThanZero(maxResults, "maxResults");
        this.minScore = ValidationUtils.ensureBetween(minScore, 0.0d, 1.0d, "minScore");
        this.meanToMaxScoreRatio = ValidationUtils.ensureBetween(meanToMaxScoreRatio, 0.0d, 1.0d, "meanToMaxScoreRatio");

        Map<E, List<Embedding>> embeddings = new HashMap<>();
        examplesByLabel.forEach((label, examples) -> {
            // A label without examples would score 0/0 = NaN and could never pass the minScore filter,
            // so skip it instead of issuing an embedAll call with an empty input.
            if (examples.isEmpty()) {
                return;
            }
            List<TextSegment> segments = examples.stream()
                    .map(TextSegment::from)
                    .collect(Collectors.toList());
            embeddings.put(label, embeddingModel.embedAll(segments).content());
        });
        this.exampleEmbeddingsByLabel = embeddings;
    }

    /**
     * Classifies {@code text} against the pre-embedded examples.
     *
     * @param text the text to classify
     * @return up to {@code maxResults} labels whose blended score is at least {@code minScore},
     *         ordered from highest to lowest score; empty if no label qualifies
     */
    @Override
    public List<E> classify(String text) {
        Embedding textEmbedding = this.embeddingModel.embed(text).content();

        List<LabelWithScore<E>> candidates = new ArrayList<>(this.exampleEmbeddingsByLabel.size());
        this.exampleEmbeddingsByLabel.forEach((label, exampleEmbeddings) ->
                candidates.add(new LabelWithScore<>(label, scoreAgainst(textEmbedding, exampleEmbeddings))));

        return candidates.stream()
                .filter(candidate -> candidate.score >= this.minScore)
                // Highest score first. Comparing the scores directly avoids the rounding
                // that sorting on (1 - score) can introduce.
                .sorted((a, b) -> Double.compare(b.score, a.score))
                .limit(this.maxResults)
                .map(candidate -> candidate.label)
                .collect(Collectors.toList());
    }

    /**
     * Computes the blended score of one label: the mean and max relevance between
     * {@code textEmbedding} and each of the label's (non-empty) example embeddings.
     */
    private double scoreAgainst(Embedding textEmbedding, List<Embedding> exampleEmbeddings) {
        double sum = 0.0d;
        double max = 0.0d;
        for (Embedding example : exampleEmbeddings) {
            double score = RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(textEmbedding, example));
            sum += score;
            max = Math.max(score, max);
        }
        return aggregatedScore(sum / exampleEmbeddings.size(), max);
    }

    /** Blends the mean and max example scores, weighted by {@code meanToMaxScoreRatio}. */
    private double aggregatedScore(double meanScore, double maxScore) {
        return (this.meanToMaxScoreRatio * meanScore) + ((1.0d - this.meanToMaxScoreRatio) * maxScore);
    }
}
