/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;

/**
 * A Naive Bayes-style {@link Classifier} that derives its class priors and per-token likelihoods
 * from BM25 scores over the training index, rather than from raw term and document counts.
 *
 * <p>For every non-empty term of the class field the classifier computes
 * {@code logPrior(class) + sum over input tokens of log(likelihood(token, class))}. It then
 * normalizes these log scores into values that sum to 1 across all classes.
 */
public class BM25NBClassifier implements Classifier<BytesRef> {
    private final IndexReader indexReader;
    private final String[] textFieldNames;
    private final String classFieldName;
    private final Analyzer analyzer;
    private final IndexSearcher indexSearcher;
    /** Optional extra query every scored document must match; may be {@code null}. */
    private final Query query;

    /**
     * Creates a new classifier over an existing index.
     *
     * @param indexReader    reader over the training index
     * @param analyzer       analyzer used to tokenize the text passed to the classify methods
     * @param query          optional query that every document considered must also match; it is
     *                       added as a MUST clause, so it also contributes to the scores. May be
     *                       {@code null}
     * @param classFieldName name of the field holding each document's class
     * @param textFieldNames names of the fields holding each document's text
     */
    public BM25NBClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
        this.indexReader = indexReader;
        this.indexSearcher = new IndexSearcher(this.indexReader);
        // Every prior/likelihood estimate below is a BM25 score, so force BM25 on this searcher.
        this.indexSearcher.setSimilarity(new BM25Similarity());
        this.textFieldNames = textFieldNames;
        this.classFieldName = classFieldName;
        this.analyzer = analyzer;
        this.query = query;
    }

    /**
     * Assigns the best-scoring class to the given text.
     *
     * @param inputDocument the text to classify
     * @return the class with the highest normalized score, or {@code null} if the class field has
     *         no non-empty terms in the index
     * @throws IOException if reading the index or analyzing the text fails
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(inputDocument);
        // The normalized list is sorted best-first, so its head is the winning class.
        return assignedClasses.isEmpty() ? null : assignedClasses.get(0);
    }

    /**
     * Returns all classes with their normalized scores, sorted by {@link ClassificationResult}'s
     * natural ordering (highest score first).
     *
     * @param text the text to classify
     * @return every candidate class; empty if the class field has no non-empty terms
     * @throws IOException if reading the index or analyzing the text fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
        Collections.sort(assignedClasses);
        return assignedClasses;
    }

    /**
     * Returns at most {@code max} classes with their normalized scores, highest score first.
     *
     * @param text the text to classify
     * @param max  maximum number of classes to return; values larger than the number of classes
     *             return all of them
     * @return a view of the first {@code min(max, #classes)} sorted results
     * @throws IOException if reading the index or analyzing the text fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
        Collections.sort(assignedClasses);
        // Clamp so that asking for more classes than exist returns all of them instead of throwing.
        return assignedClasses.subList(0, Math.min(max, assignedClasses.size()));
    }

    /**
     * Scores every non-empty class term against the tokenized input and normalizes the scores.
     *
     * @param inputDocument the text to classify
     * @return normalized results sorted best-first; empty if the class field is absent or has no
     *         non-empty terms
     * @throws IOException if reading the index or analyzing the text fails
     */
    private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
        Terms classes = MultiTerms.getTerms(this.indexReader, this.classFieldName);
        if (classes == null) {
            // The class field does not exist in the index, so there is nothing to classify against.
            return assignedClasses;
        }
        TermsEnum classesEnum = classes.iterator();
        String[] tokenizedText = tokenize(inputDocument);
        BytesRef next;
        while ((next = classesEnum.next()) != null) {
            if (next.length > 0) {
                // TermsEnum may reuse the returned BytesRef on the next call, and Term does not
                // copy it. Take a private copy before it is stored in a result.
                BytesRef classValue = BytesRef.deepCopyOf(next);
                Term term = new Term(this.classFieldName, classValue);
                double logScore = calculateLogPrior(term) + calculateLogLikelihood(tokenizedText, term);
                assignedClasses.add(new ClassificationResult<>(classValue, logScore));
            }
        }
        return normClassificationResults(assignedClasses);
    }

    /**
     * Converts per-class log scores into values that sum to 1, using the log-sum-exp trick.
     *
     * <p>Each result score {@code s} becomes {@code exp(s - log(sum(exp(s_i))))}.
     *
     * @param assignedClasses per-class log scores; this list is sorted in place
     * @return a new list of normalized results, in the same (sorted) order as the input
     */
    private List<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) {
        List<ClassificationResult<BytesRef>> returnList = new ArrayList<>(assignedClasses.size());
        if (assignedClasses.isEmpty()) {
            return returnList;
        }
        Collections.sort(assignedClasses);
        // Shift by the maximum so exp() cannot overflow. Compute the maximum explicitly rather
        // than relying on the sort direction.
        double smax = Double.NEGATIVE_INFINITY;
        for (ClassificationResult<BytesRef> cr : assignedClasses) {
            smax = Math.max(smax, cr.getScore());
        }
        double sumExp = 0.0;
        for (ClassificationResult<BytesRef> cr : assignedClasses) {
            sumExp += Math.exp(cr.getScore() - smax);
        }
        // log(sum(exp(s_i))) = smax + log(sum(exp(s_i - smax)))
        double logNormalizer = smax + Math.log(sumExp);
        for (ClassificationResult<BytesRef> cr : assignedClasses) {
            returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - logNormalizer)));
        }
        return returnList;
    }

    /**
     * Analyzes {@code text} once per configured text field and concatenates the resulting tokens.
     *
     * @param text the text to analyze
     * @return the tokens of every field's analysis, in field order
     * @throws IOException if analysis fails
     */
    private String[] tokenize(String text) throws IOException {
        List<String> result = new ArrayList<>();
        for (String textFieldName : this.textFieldNames) {
            try (TokenStream tokenStream = this.analyzer.tokenStream(textFieldName, text)) {
                CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                while (tokenStream.incrementToken()) {
                    result.add(charTermAttribute.toString());
                }
                tokenStream.end();
            }
        }
        return result.toArray(new String[0]);
    }

    /**
     * Sums the log of the per-token class score over all tokens.
     *
     * @param tokens analyzed input tokens
     * @param term   the class term
     * @return {@code sum(log(getTermProbForClass(term, token)))}
     * @throws IOException if searching the index fails
     */
    private double calculateLogLikelihood(String[] tokens, Term term) throws IOException {
        double result = 0.0;
        for (String word : tokens) {
            result += Math.log(getTermProbForClass(term, word));
        }
        return result;
    }

    /**
     * Returns the top BM25 score among documents of the given class, where occurrences of the
     * given words in any text field raise the score.
     *
     * <p>The class term (and the optional query) are MUST clauses; the words are SHOULD clauses,
     * so they are optional for matching and only contribute to scoring.
     *
     * @param classTerm the class term every matching document must contain
     * @param words     words whose presence in any text field boosts the score
     * @return the best matching document's score, or {@code 1.0} (log 0) if no document matches
     * @throws IOException if searching the index fails
     */
    private double getTermProbForClass(Term classTerm, String... words) throws IOException {
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        builder.add(new BooleanClause(new TermQuery(classTerm), BooleanClause.Occur.MUST));
        for (String textFieldName : this.textFieldNames) {
            for (String word : words) {
                builder.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
            }
        }
        if (this.query != null) {
            builder.add(this.query, BooleanClause.Occur.MUST);
        }
        TopDocs search = this.indexSearcher.search(builder.build(), 1);
        return search.totalHits.value > 0L ? (double) search.scoreDocs[0].score : 1.0;
    }

    /**
     * Returns the log of the top BM25 score among documents matching the class term and the
     * optional query.
     *
     * @param term the class term
     * @return {@code log(topScore)}, or {@code 0.0} if no document matches
     * @throws IOException if searching the index fails
     */
    private double calculateLogPrior(Term term) throws IOException {
        BooleanQuery.Builder bq = new BooleanQuery.Builder();
        bq.add(new TermQuery(term), BooleanClause.Occur.MUST);
        if (this.query != null) {
            bq.add(this.query, BooleanClause.Occur.MUST);
        }
        TopDocs topDocs = this.indexSearcher.search(bq.build(), 1);
        return topDocs.totalHits.value > 0L ? Math.log(topDocs.scoreDocs[0].score) : 0.0;
    }
}

