/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.codec;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.BlockTermState;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.neuralsearch.sparse.algorithm.ClusterTrainingExecutor;
import org.opensearch.neuralsearch.sparse.algorithm.seismic.BatchClusteringTask;
import org.opensearch.neuralsearch.sparse.cache.CacheKey;
import org.opensearch.neuralsearch.sparse.codec.ClusteredPostingTermsWriter;
import org.opensearch.neuralsearch.sparse.codec.MergeHelper;
import org.opensearch.neuralsearch.sparse.codec.SparseTermsLuceneWriter;
import org.opensearch.neuralsearch.sparse.common.MergeStateFacade;
import org.opensearch.neuralsearch.sparse.common.PredicateUtils;
import org.opensearch.neuralsearch.sparse.data.PostingClusters;
import org.opensearch.neuralsearch.sparse.mapper.SparseVectorField;

/**
 * Merges the sparse-vector postings of the segments described by a {@link MergeStateFacade}.
 *
 * <p>For every sparse field accepted by {@code PredicateUtils.shouldRunSeisPredicate}, all merged terms are
 * collected, split into batches of {@link #BATCH_SIZE}, clustered by {@link BatchClusteringTask}, and the
 * resulting {@link PostingClusters} are written through a {@link ClusteredPostingTermsWriter} and a
 * {@link SparseTermsLuceneWriter}.
 */
public class SparsePostingsReader {
    @Generated
    private static final Logger log = LogManager.getLogger(SparsePostingsReader.class);
    /** Number of terms handed to a single {@link BatchClusteringTask}. */
    private static final int BATCH_SIZE = 50;
    /** Value of the {@code n_postings} field attribute meaning "derive from the merged doc count". */
    private static final int AUTO_N_POSTINGS = -1;
    /** Fraction of the merged doc count used when {@code n_postings} is derived automatically. */
    private static final float AUTO_N_POSTINGS_DOC_RATIO = 5.0E-4f;
    /** Lower bound applied when {@code n_postings} is derived automatically. */
    private static final int AUTO_N_POSTINGS_MIN = 160;

    private final MergeStateFacade mergeStateFacade;
    private final MergeHelper mergeHelper;

    /**
     * Writes the merged, clustered postings of every eligible sparse field.
     *
     * <p>A field is processed when {@link SparseVectorField#isSparseField} is true and
     * {@code PredicateUtils.shouldRunSeisPredicate} accepts it for the merged segment. The field count is
     * written first; then, per field, its field number, its term count and one entry per term.
     *
     * @param sparseTermsLuceneWriter receives the field count, field numbers, term counts and per-term states
     * @param clusteredPostingTermsWriter writes each term's clustered postings and returns its term state
     * @throws IOException if writing fails; both writers are first closed via {@code closeWithException()}
     * @throws Exception as declared; other failures (e.g. missing or malformed field attributes) propagate
     *     without closing the writers
     */
    public void merge(SparseTermsLuceneWriter sparseTermsLuceneWriter, ClusteredPostingTermsWriter clusteredPostingTermsWriter)
        throws Exception {
        // Total document count over all merged segments; used for the auto n_postings default and maxDoc.
        int docCount = 0;
        for (int maxDoc : this.mergeStateFacade.getMaxDocs()) {
            docCount += maxDoc;
        }
        log.debug("Merge total doc: {}", docCount);

        List<FieldInfo> sparseFieldInfos = new ArrayList<>();
        for (FieldInfo fieldInfo : this.mergeStateFacade.getMergeFieldInfos()) {
            if (SparseVectorField.isSparseField(fieldInfo)
                && PredicateUtils.shouldRunSeisPredicate.test(this.mergeStateFacade.getSegmentInfo(), fieldInfo)) {
                sparseFieldInfos.add(fieldInfo);
            }
        }

        try {
            sparseTermsLuceneWriter.writeFieldCount(sparseFieldInfos.size());
            for (FieldInfo fieldInfo : sparseFieldInfos) {
                mergeField(fieldInfo, docCount, sparseTermsLuceneWriter, clusteredPostingTermsWriter);
            }
        } catch (IOException ex) {
            clusteredPostingTermsWriter.closeWithException();
            sparseTermsLuceneWriter.closeWithException();
            throw ex;
        }
    }

    /**
     * Clusters all merged terms of one sparse field and writes them in batch order.
     *
     * @param fieldInfo the field being merged; its {@code cluster_ratio}, {@code n_postings} and
     *     {@code summary_prune_ratio} attributes are parsed here
     * @param docCount total number of documents across the merged segments
     * @param sparseTermsLuceneWriter receives the field number, term count and per-term states
     * @param clusteredPostingTermsWriter writes each term's clustered postings
     * @throws Exception propagated from the writers, term collection or inline clustering
     */
    private void mergeField(
        FieldInfo fieldInfo,
        int docCount,
        SparseTermsLuceneWriter sparseTermsLuceneWriter,
        ClusteredPostingTermsWriter clusteredPostingTermsWriter
    ) throws Exception {
        log.debug("Merge field: {}", fieldInfo.name);
        sparseTermsLuceneWriter.writeFieldNumber(fieldInfo.getFieldNumber());
        CacheKey key = new CacheKey(this.mergeStateFacade.getSegmentInfo(), fieldInfo);
        float clusterRatio = Float.parseFloat(fieldInfo.attributes().get("cluster_ratio"));
        int nPostings = resolveNPostings(fieldInfo, docCount);
        float summaryPruneRatio = Float.parseFloat(fieldInfo.attributes().get("summary_prune_ratio"));
        Set<BytesRef> allTerms = this.mergeHelper.getAllTerms(this.mergeStateFacade, fieldInfo);
        sparseTermsLuceneWriter.writeTermsSize(allTerms.size());
        clusteredPostingTermsWriter.setFieldAndMaxDoc(fieldInfo, docCount, true);

        // One future per batch; "+ 1" avoids under-sizing when the last batch is partial.
        List<CompletableFuture<List<Pair<BytesRef, PostingClusters>>>> futures = new ArrayList<>(
            allTerms.size() / BATCH_SIZE + 1
        );
        List<BytesRef> termBatch = new ArrayList<>(BATCH_SIZE);
        int remaining = allTerms.size();
        for (BytesRef term : allTerms) {
            termBatch.add(term);
            --remaining;
            // Submit a full batch, or the trailing partial batch once the last term has been added.
            if (termBatch.size() == BATCH_SIZE || remaining == 0) {
                BatchClusteringTask task = new BatchClusteringTask(
                    termBatch,
                    key,
                    summaryPruneRatio,
                    clusterRatio,
                    nPostings,
                    this.mergeStateFacade,
                    fieldInfo,
                    this.mergeHelper
                );
                if (clusterRatio == 0.0f) {
                    // With a zero cluster ratio the task runs inline on the calling thread; any exception it
                    // throws propagates directly rather than as a CompletionException.
                    futures.add(CompletableFuture.completedFuture(task.get()));
                } else {
                    futures.add(CompletableFuture.supplyAsync(task, ClusterTrainingExecutor.getInstance().getExecutor()));
                }
                // The submitted task keeps a reference to termBatch, so start a fresh list instead of clearing it.
                termBatch = new ArrayList<>(BATCH_SIZE);
            }
        }

        // Consume results in submission order so terms are written in the same order they were batched.
        for (int j = 0; j < futures.size(); ++j) {
            try {
                List<Pair<BytesRef, PostingClusters>> clusters = futures.get(j).join();
                // Drop the list's reference to this future once it has been consumed.
                futures.set(j, null);
                for (Pair<BytesRef, PostingClusters> p : clusters) {
                    BlockTermState state = clusteredPostingTermsWriter.write(p.getLeft(), p.getRight());
                    sparseTermsLuceneWriter.writeTerm(p.getLeft(), state);
                }
            } catch (CancellationException | CompletionException ex) {
                // NOTE(review): a failed batch is only logged, so its terms are not written even though
                // writeTermsSize(allTerms.size()) has already been emitted. Confirm the reader tolerates a
                // shorter term list; otherwise this should fail the merge instead.
                log.error("Thread of running clustering from {}th term batch during merge has exception", j, ex);
            }
        }
    }

    /**
     * Resolves the {@code n_postings} field attribute.
     *
     * @param fieldInfo field whose {@code n_postings} attribute is parsed as an int
     * @param docCount total number of documents across the merged segments
     * @return the configured value, or, when it equals {@link #AUTO_N_POSTINGS}, the larger of
     *     {@code (int) (AUTO_N_POSTINGS_DOC_RATIO * docCount)} and {@link #AUTO_N_POSTINGS_MIN}
     */
    private static int resolveNPostings(FieldInfo fieldInfo, int docCount) {
        int configured = Integer.parseInt(fieldInfo.attributes().get("n_postings"));
        if (configured != AUTO_N_POSTINGS) {
            return configured;
        }
        return Math.max((int) (AUTO_N_POSTINGS_DOC_RATIO * docCount), AUTO_N_POSTINGS_MIN);
    }

    @Generated
    public SparsePostingsReader(MergeStateFacade mergeStateFacade, MergeHelper mergeHelper) {
        this.mergeStateFacade = mergeStateFacade;
        this.mergeHelper = mergeHelper;
    }
}

