/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.Generated;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.dto.ExplainDTO;
import org.opensearch.neuralsearch.processor.dto.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationUtil;
import org.opensearch.neuralsearch.processor.util.ProcessorUtils;

/**
 * Score normalization technique that divides every score by the L2 (Euclidean) norm
 * of the scores found at the same sub-query position.
 *
 * <p>The norm for sub-query position {@code i} is the square root of the sum of squared
 * scores of all {@link ScoreDoc}s at position {@code i}, accumulated over every non-null
 * {@link CompoundTopDocs} in the input list. Scores are rewritten in place.
 */
public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
    public static final String TECHNIQUE_NAME = "l2";
    /** Value assigned to a score when the L2 norm of its sub-query is zero. */
    private static final float MIN_SCORE = 0.0f;

    /**
     * Creates the technique with an empty parameter map and a fresh {@link ScoreNormalizationUtil}.
     */
    public L2ScoreNormalizationTechnique() {
        this(Map.of(), new ScoreNormalizationUtil());
    }

    /**
     * Creates the technique after passing {@code params} to
     * {@link ScoreNormalizationUtil#validateParameters} with an empty set and an empty map.
     *
     * @param params                 user-supplied parameters to validate
     * @param scoreNormalizationUtil helper used to validate {@code params}
     */
    public L2ScoreNormalizationTechnique(Map<String, Object> params, ScoreNormalizationUtil scoreNormalizationUtil) {
        scoreNormalizationUtil.validateParameters(params, Set.of(), Map.of());
    }

    /**
     * Replaces the score of every {@link ScoreDoc} with its L2-normalized value.
     * Null entries of the query top-docs list are skipped.
     *
     * @param normalizeScoresDTO carrier of the per-shard query results to normalize in place
     */
    @Override
    public void normalize(NormalizeScoresDTO normalizeScoresDTO) {
        final List<CompoundTopDocs> shardResults = normalizeScoresDTO.getQueryTopDocs();
        final List<Float> norms = getL2Norm(shardResults);
        for (CompoundTopDocs shardResult : shardResults) {
            if (shardResult == null) {
                continue;
            }
            final List<TopDocs> hitsPerSubQuery = shardResult.getTopDocs();
            int position = 0;
            while (position < hitsPerSubQuery.size()) {
                for (ScoreDoc hit : hitsPerSubQuery.get(position).scoreDocs) {
                    // The norm is looked up per hit, so a position without hits never touches the list.
                    hit.score = normalizeSingleScore(hit.score, norms.get(position));
                }
                position++;
            }
        }
    }

    /**
     * @return the name of this technique, {@value #TECHNIQUE_NAME}
     */
    @Override
    public String techniqueName() {
        return TECHNIQUE_NAME;
    }

    /**
     * @return a description of this technique; it consists of the technique name only
     */
    @Override
    public String describe() {
        return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME);
    }

    /**
     * Normalizes all scores in place, exactly like {@link #normalize}, while also recording the
     * normalized score of every document per sub-query position so it can be turned into
     * explanation details.
     *
     * @param explainDTO carrier of the per-shard query results
     * @return the result of {@link ExplanationUtils#getDocIdAtQueryForNormalization} for the
     *         recorded scores and this technique
     */
    @Override
    public Map<DocIdAtSearchShard, ExplanationDetails> explain(ExplainDTO explainDTO) {
        final List<CompoundTopDocs> shardResults = explainDTO.getQueryTopDocs();
        final HashMap<DocIdAtSearchShard, List<Float>> scoresByDocument = new HashMap<>();
        final List<Float> norms = getL2Norm(shardResults);
        for (CompoundTopDocs shardResult : shardResults) {
            if (shardResult == null) {
                continue;
            }
            final List<TopDocs> hitsPerSubQuery = shardResult.getTopDocs();
            final int subQueryCount = hitsPerSubQuery.size();
            for (int position = 0; position < subQueryCount; position++) {
                for (ScoreDoc hit : hitsPerSubQuery.get(position).scoreDocs) {
                    final DocIdAtSearchShard documentKey = new DocIdAtSearchShard(hit.doc, shardResult.getSearchShard());
                    final float rescaled = normalizeSingleScore(hit.score, norms.get(position));
                    // Record first, then overwrite the raw score.
                    ScoreNormalizationUtil.setNormalizedScore(scoresByDocument, documentKey, position, subQueryCount, rescaled);
                    hit.score = rescaled;
                }
            }
        }
        return ExplanationUtils.getDocIdAtQueryForNormalization(scoresByDocument, this);
    }

    /**
     * Computes one L2 norm per sub-query position.
     *
     * @param queryTopDocs per-shard query results; null entries are ignored
     * @return a list whose size is {@link ProcessorUtils#getNumOfSubqueries} of the input and whose
     *         element {@code i} is the square root of the sum of squared scores at position {@code i}
     */
    private List<Float> getL2Norm(List<CompoundTopDocs> queryTopDocs) {
        final int subQueryCount = ProcessorUtils.getNumOfSubqueries(queryTopDocs);
        final float[] sumsOfSquares = new float[subQueryCount];
        for (CompoundTopDocs shardResult : queryTopDocs) {
            if (shardResult == null) {
                continue;
            }
            int position = 0;
            for (TopDocs subQueryHits : shardResult.getTopDocs()) {
                for (ScoreDoc hit : subQueryHits.scoreDocs) {
                    sumsOfSquares[position] += hit.score * hit.score;
                }
                position++;
            }
        }
        final List<Float> norms = new ArrayList<>();
        for (float sumOfSquares : sumsOfSquares) {
            norms.add((float) Math.sqrt(sumOfSquares));
        }
        return norms;
    }

    /**
     * Divides {@code score} by {@code l2Norm}, falling back to {@link #MIN_SCORE} when the norm is zero.
     *
     * @param score  raw score
     * @param l2Norm L2 norm of the sub-query the score belongs to
     * @return the normalized score
     */
    private float normalizeSingleScore(float score, float l2Norm) {
        if (l2Norm == 0.0f) {
            return MIN_SCORE;
        }
        return score / l2Norm;
    }

    @Override
    @Generated
    public String toString() {
        return "L2ScoreNormalizationTechnique(TECHNIQUE_NAME=l2)";
    }
}

