/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.combination;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;

/**
 * Combines the per-sub-query scores held by {@link CompoundTopDocs} objects into a single
 * score per document, using a pluggable {@link ScoreCombinationTechnique}.
 *
 * <p>For each {@link CompoundTopDocs} the combiner collects, per document id, one score slot
 * per sub-query, asks the technique to reduce those slots to one value, orders the documents
 * by that value (highest first) and writes the resulting score docs and total hits back into
 * the {@link CompoundTopDocs} instance.
 */
public class ScoreCombiner {
    @Generated
    private static final Logger log = LogManager.getLogger(ScoreCombiner.class);

    /** Shard index used when no existing score doc is available to copy one from (Lucene's "unset" value). */
    private static final int UNSET_SHARD_INDEX = -1;

    /**
     * Combines scores in place for every element of {@code queryTopDocs}.
     *
     * @param queryTopDocs              query results to update; {@code null} elements and elements
     *                                  whose total hit count is zero are left untouched
     * @param scoreCombinationTechnique technique used to reduce a document's per-sub-query scores
     *                                  to a single score
     */
    public void combineScores(List<CompoundTopDocs> queryTopDocs, ScoreCombinationTechnique scoreCombinationTechnique) {
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs);
        }
    }

    /**
     * Runs the full combination pipeline for one {@link CompoundTopDocs}: gather per-document
     * scores, combine them, sort document ids by combined score and write the result back.
     */
    private void combineShardScores(ScoreCombinationTechnique scoreCombinationTechnique, CompoundTopDocs compoundQueryTopDocs) {
        if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0L) {
            return;
        }
        List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
        Map<Integer, float[]> normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery);
        Map<Integer, Float> combinedNormalizedScoresByDocId = combineScoresAndGetCombinedNormalizedScoresPerDocument(
            normalizedScoresPerDoc,
            scoreCombinationTechnique
        );
        List<Integer> sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId);
        updateQueryTopDocsWithCombinedScores(compoundQueryTopDocs, topDocsPerSubQuery, combinedNormalizedScoresByDocId, sortedDocsIds);
    }

    /**
     * Returns the document ids of {@code combinedNormalizedScoresByDocId} ordered by their
     * combined score, highest score first.
     */
    private List<Integer> getSortedDocIds(Map<Integer, Float> combinedNormalizedScoresByDocId) {
        List<Integer> sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet());
        // Arguments are swapped (b before a) to obtain descending order.
        sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a)));
        return sortedDocsIds;
    }

    /**
     * Builds the final list of score docs: the first {@code maxHits} document ids of
     * {@code sortedScores}, each paired with its combined score.
     *
     * <p>The list contains exactly {@code min(maxHits, sortedScores.size())} entries and never
     * holds {@code null} elements, even when fewer distinct documents than {@code maxHits} exist.
     *
     * @param compoundQueryTopDocs            source of the shard index copied onto every new score doc
     * @param combinedNormalizedScoresByDocId combined score per document id
     * @param sortedScores                    document ids ordered by combined score, highest first
     * @param maxHits                         upper bound on the number of score docs returned
     * @return a new mutable list of score docs
     */
    private List<ScoreDoc> getCombinedScoreDocs(CompoundTopDocs compoundQueryTopDocs, Map<Integer, Float> combinedNormalizedScoresByDocId, List<Integer> sortedScores, int maxHits) {
        int resultSize = Math.min(maxHits, sortedScores.size());
        // The shard index is taken from the first existing score doc; if there is none to copy
        // from, fall back to the "unset" value instead of failing with an index error.
        List<ScoreDoc> currentScoreDocs = compoundQueryTopDocs.getScoreDocs();
        int shardId = Objects.isNull(currentScoreDocs) || currentScoreDocs.isEmpty()
            ? UNSET_SHARD_INDEX
            : currentScoreDocs.get(0).shardIndex;
        List<ScoreDoc> combinedScoreDocs = new ArrayList<>(resultSize);
        for (int j = 0; j < resultSize; ++j) {
            int docId = sortedScores.get(j);
            combinedScoreDocs.add(new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId));
        }
        return combinedScoreDocs;
    }

    /**
     * Groups scores by document id across all sub-queries.
     *
     * @param topDocsPerSubQuery one {@link TopDocs} per sub-query
     * @return map from document id to an array with one slot per sub-query; slot {@code j} holds
     *         the score the document received in sub-query {@code j}, or {@code 0.0f} when the
     *         document does not appear in that sub-query's score docs
     */
    public Map<Integer, float[]> getNormalizedScoresPerDocument(List<TopDocs> topDocsPerSubQuery) {
        int subQueryCount = topDocsPerSubQuery.size();
        Map<Integer, float[]> normalizedScoresPerDoc = new HashMap<>();
        for (int j = 0; j < subQueryCount; ++j) {
            TopDocs topDocs = topDocsPerSubQuery.get(j);
            for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
                // computeIfAbsent returns the existing or freshly created array, so no second lookup is needed.
                float[] scores = normalizedScoresPerDoc.computeIfAbsent(scoreDoc.doc, key -> new float[subQueryCount]);
                scores[j] = scoreDoc.score;
            }
        }
        return normalizedScoresPerDoc;
    }

    /**
     * Applies {@code scoreCombinationTechnique} to each document's score array.
     *
     * @return map from document id to the combined score
     */
    private Map<Integer, Float> combineScoresAndGetCombinedNormalizedScoresPerDocument(Map<Integer, float[]> normalizedScoresPerDocument, ScoreCombinationTechnique scoreCombinationTechnique) {
        return normalizedScoresPerDocument.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue())));
    }

    /**
     * Writes the combined score docs and the recomputed total hits into {@code compoundQueryTopDocs}.
     */
    private void updateQueryTopDocsWithCombinedScores(CompoundTopDocs compoundQueryTopDocs, List<TopDocs> topDocsPerSubQuery, Map<Integer, Float> combinedNormalizedScoresByDocId, List<Integer> sortedScores) {
        int maxHits = getMaxHits(topDocsPerSubQuery);
        compoundQueryTopDocs.setScoreDocs(getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits));
        compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits));
    }

    /**
     * Returns the largest number of score docs held by any single sub-query result.
     *
     * @param topDocsPerSubQuery one {@link TopDocs} per sub-query
     * @return the maximum {@code scoreDocs.length}, or {@code 0} for an empty list
     */
    protected int getMaxHits(List<TopDocs> topDocsPerSubQuery) {
        int maxHits = 0;
        for (TopDocs topDocs : topDocsPerSubQuery) {
            maxHits = Math.max(maxHits, topDocs.scoreDocs.length);
        }
        return maxHits;
    }

    /**
     * Builds the total hits for the combined result: the value is {@code maxHits}, and the
     * relation is {@code GREATER_THAN_OR_EQUAL_TO} if any sub-query reported that relation,
     * otherwise {@code EQUAL_TO}.
     */
    private TotalHits getTotalHits(List<TopDocs> topDocsPerSubQuery, int maxHits) {
        boolean anyLowerBound = topDocsPerSubQuery.stream()
            .anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
        TotalHits.Relation relation = anyLowerBound ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO : TotalHits.Relation.EQUAL_TO;
        return new TotalHits((long)maxHits, relation);
    }
}

