package org.nuxeo.ecm.platform.documentcategorization.categorizer.tfidf;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Token;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.nuxeo.ecm.platform.documentcategorization.service.Categorizer;

/* loaded from: input_file:org/nuxeo/ecm/platform/documentcategorization/categorizer/tfidf/TfIdfCategorizer.class */
/**
 * Categorizer that scores a text against a set of learned topics using TF-IDF weighted term
 * vectors compared by cosine similarity (dot product divided by the product of the norms).
 * <p>
 * Each topic accumulates term counts in a {@code long[]} of length {@link #getDimension()} produced
 * by the {@link HashingVectorizer}. The IDF weights are derived from the counts summed over all
 * topics. A trained model is persisted as a GZIP-compressed Java serialization stream, see
 * {@link #saveToStream(OutputStream)} and {@link #load(InputStream)}.
 * <p>
 * Only {@link #disableUpdate()} is {@code synchronized}; training through the {@code update}
 * methods is not guarded against concurrent calls.
 */
public class TfIdfCategorizer extends PrimitiveVectorHelper implements Categorizer, Serializable {

    private static final long serialVersionUID = 1;

    public static final Log log = LogFactory.getLog(TfIdfCategorizer.class);

    /** Names of all topics learned so far, in natural (alphabetical) order. */
    protected final Set<String> topicNames;

    /** Per-topic accumulated term counts; values are {@code long[]} of length {@link #dim}. */
    protected final Map<String, Object> topicTermCount;

    /** Per-topic cache of TF-IDF vectors; values are {@code float[]}. */
    protected final Map<String, Object> cachedTopicTfIdf;

    /** Per-topic cache of the norm of the topic TF-IDF vector. */
    protected final Map<String, Float> cachedTopicTfIdfNorm;

    /** Term counts summed over all topics; dropped (set to null) by {@link #disableUpdate()}. */
    protected long[] allTermCounts;

    /** Length of every term count and TF-IDF vector handled by this model. */
    protected final int dim;

    /** Lazily computed IDF weights; reset to null whenever the counts change. */
    protected float[] cachedIdf;

    /** Total number of term occurrences seen across all topics. */
    protected long totalTermCount;

    /** Turns a token list into a {@code long[]} count vector of length {@link #dim}. */
    protected final HashingVectorizer vectorizer;

    /** Analyzer used by {@link #tokenize(String)}; transient, so re-created lazily after loading. */
    protected transient Analyzer analyzer;

    /** Default minimal "similarity / median similarity" ratio used by guessCategories. */
    protected Double ratioOverMedian;

    /** When true the model is read-only, see {@link #disableUpdate()}. */
    protected boolean updateDisabled;

    /** Creates an empty model with the default dimension of 524288 (2^19). */
    public TfIdfCategorizer() {
        this(524288);
    }

    /**
     * Creates an empty model.
     *
     * @param i dimension of the term count vectors; must be strictly positive
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public TfIdfCategorizer(int i) {
        if (i <= 0) {
            // Every count array is allocated with this length and given to the vectorizer:
            // a negative size cannot be allocated and a zero size can never hold any term.
            throw new IllegalArgumentException("model dimension must be strictly positive, got " + i);
        }
        this.topicNames = new TreeSet<String>();
        this.topicTermCount = new ConcurrentHashMap<String, Object>();
        this.cachedTopicTfIdf = new ConcurrentHashMap<String, Object>();
        this.cachedTopicTfIdfNorm = new ConcurrentHashMap<String, Float>();
        this.totalTermCount = 0L;
        this.ratioOverMedian = Double.valueOf(3.0d);
        this.updateDisabled = false;
        this.dim = i;
        this.allTermCounts = new long[i];
        this.vectorizer = new HashingVectorizer().dimension(i);
    }

    /** @return the vectorizer used to turn token lists into count vectors */
    public HashingVectorizer getVectorizer() {
        return this.vectorizer;
    }

    /**
     * Returns the analyzer used for tokenization, creating a Lucene {@link StandardAnalyzer} on
     * first use (the field is transient and therefore null after deserialization).
     */
    public Analyzer getAnalyzer() {
        if (this.analyzer == null) {
            this.analyzer = new StandardAnalyzer();
        }
        return this.analyzer;
    }

    /**
     * Freezes the model: precomputes the IDF weights and every topic TF-IDF vector and norm, then
     * releases the raw per-topic and global counts to save memory. Afterwards {@code update}
     * and {@link #saveToStream(OutputStream)} throw {@link IllegalStateException}.
     */
    public synchronized void disableUpdate() {
        this.updateDisabled = true;
        // Fill every cache now: the raw counts needed to recompute them are dropped below.
        getIdf();
        for (String topic : this.topicNames) {
            tfidf(topic);
            tfidfNorm(topic);
        }
        this.topicTermCount.clear();
        this.allTermCounts = null;
    }

    /**
     * Adds the counts of the given tokens to a topic, creating the topic if needed.
     *
     * @param str the topic name
     * @param list the tokens of a training text
     * @throws IllegalStateException if {@link #disableUpdate()} has been called
     */
    public void update(String str, List<String> list) {
        if (this.updateDisabled) {
            throw new IllegalStateException("updates are no longer authorized once #disableUpdate has been called");
        }
        long[] counts = this.vectorizer.count(list);
        this.totalTermCount += sum(counts);
        long[] topicCounts = (long[]) this.topicTermCount.get(str);
        if (topicCounts == null) {
            topicCounts = new long[this.dim];
            this.topicTermCount.put(str, topicCounts);
            this.topicNames.add(str);
        }
        add(topicCounts, counts);
        add(this.allTermCounts, counts);
        // The topic vector and the global IDF both depend on the counts that just changed.
        invalidateCache(str);
    }

    /**
     * Tokenizes {@code str2} with {@link #tokenize(String)} and adds it to topic {@code str}.
     *
     * @see #update(String, List)
     */
    public void update(String str, String str2) {
        update(str, tokenize(str2));
    }

    /** Drops the cached TF-IDF vector and norm of one topic, and the cached IDF weights. */
    protected void invalidateCache(String str) {
        this.cachedTopicTfIdf.remove(str);
        this.cachedTopicTfIdfNorm.remove(str);
        this.cachedIdf = null;
    }

    /** Drops every cached value so that they get recomputed (or not serialized) afterwards. */
    protected void invalidateCache() {
        for (String topic : this.topicNames) {
            invalidateCache(topic);
        }
        // Also reset the IDF cache when there is no topic: the per-topic loop above would not
        // reach it, and saveToStream relies on this method to avoid serializing stale caches.
        this.cachedIdf = null;
    }

    /**
     * Computes the similarity of a token list with every topic.
     *
     * @param list the tokens of the text to categorize
     * @return topic name to cosine similarity, ordered by decreasing similarity; topics whose
     *         TF-IDF norm is zero are omitted, and the map is empty when the text's own TF-IDF
     *         norm is zero
     */
    public Map<String, Float> getSimilarities(List<String> list) {
        Map<String, Float> similarities = new TreeMap<String, Float>();
        float[] queryTfIdf = getTfIdf(this.vectorizer.count(list));
        float queryNorm = normOf(queryTfIdf);
        if (queryNorm == 0.0f) {
            return similarities;
        }
        for (String topic : this.topicNames) {
            float topicNorm = tfidfNorm(topic);
            if (topicNorm != 0.0f) {
                float[] topicTfIdf = tfidf(topic);
                similarities.put(topic, Float.valueOf(dot(queryTfIdf, topicTfIdf) / (queryNorm * topicNorm)));
            }
        }
        return sortByDecreasingValue(similarities);
    }

    /**
     * Tokenizes {@code str} with {@link #tokenize(String)} and scores it against every topic.
     *
     * @see #getSimilarities(List)
     */
    public Map<String, Float> getSimilarities(String str) {
        return getSimilarities(tokenize(str));
    }

    /** Returns the (cached) norm of the TF-IDF vector of a topic. */
    protected float tfidfNorm(String str) {
        Float norm = this.cachedTopicTfIdfNorm.get(str);
        if (norm == null) {
            norm = Float.valueOf(normOf(tfidf(str)));
            this.cachedTopicTfIdfNorm.put(str, norm);
        }
        return norm.floatValue();
    }

    /** Returns the (cached) TF-IDF vector of a topic, computed from its accumulated counts. */
    protected float[] tfidf(String str) {
        float[] weights = (float[]) this.cachedTopicTfIdf.get(str);
        if (weights == null) {
            weights = getTfIdf((long[]) this.topicTermCount.get(str));
            this.cachedTopicTfIdf.put(str, weights);
        }
        return weights;
    }

    /**
     * Turns raw counts into TF-IDF weights: {@code (count[k] / sum(count)) * idf[k]}.
     *
     * @param jArr term counts, expected to have length {@link #dim}
     * @return a new array of weights; all zeros when the counts sum to zero
     */
    protected float[] getTfIdf(long[] jArr) {
        float[] idf = getIdf();
        float[] weights = new float[jArr.length];
        long total = sum(jArr);
        if (total == 0) {
            return weights;
        }
        for (int k = 0; k < jArr.length; k++) {
            weights[k] = (((float) jArr[k]) / ((float) total)) * idf[k];
        }
        return weights;
    }

    /**
     * Returns the (cached) IDF weights: {@code log(1 + totalTermCount / allTermCounts[k])}, or 0
     * for slots never seen. Note the weighting is based on term occurrence counts over the whole
     * corpus rather than on document frequencies.
     */
    protected float[] getIdf() {
        if (this.cachedIdf == null) {
            float[] idf = new float[this.allTermCounts.length];
            for (int k = 0; k < this.allTermCounts.length; k++) {
                if (this.allTermCounts[k] == 0) {
                    idf[k] = 0.0f;
                } else {
                    idf[k] = (float) Math.log1p(((float) this.totalTermCount) / ((float) this.allTermCounts[k]));
                }
            }
            this.cachedIdf = idf;
        }
        return this.cachedIdf;
    }

    /** @return the length of the term vectors of this model */
    public int getDimension() {
        return this.dim;
    }

    /**
     * Trains the model on every regular file directly inside {@code file}. Each file is read
     * line by line as UTF-8 and learned under a topic named after the file name up to its first
     * '.' (e.g. {@code sports.txt} trains topic {@code sports}). Sub-folders are ignored.
     *
     * @param file the folder holding the training files
     * @throws IOException if {@code file} is not a folder, cannot be listed, or a file cannot be read
     */
    public void learnFiles(File file) throws IOException {
        if (!file.isDirectory()) {
            throw new IOException(String.format("%s is not a folder", file.getAbsolutePath()));
        }
        File[] children = file.listFiles();
        if (children == null) {
            // listFiles() returns null (instead of throwing) when the folder cannot be read.
            throw new IOException(String.format("could not list the content of %s", file.getAbsolutePath()));
        }
        for (File child : children) {
            if (!child.isDirectory()) {
                learnFile(child);
            }
        }
    }

    /** Learns every line of one UTF-8 text file under the topic derived from its name. */
    private void learnFile(File file) throws IOException {
        String topic = file.getName();
        int extensionStart = topic.indexOf('.');
        if (extensionStart >= 0) {
            topic = topic.substring(0, extensionStart);
        }
        log.info(String.format("About to analyze file %s", file.getAbsolutePath()));
        FileInputStream fileInputStream = new FileInputStream(file);
        try {
            BufferedReader reader = new BufferedReader(new InputStreamReader(fileInputStream, Charset.forName("UTF-8")));
            int lineCount = 0;
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                update(topic, line);
                lineCount++;
                if (lineCount % 10000 == 0) {
                    log.info(String.format("Analyzed %d lines from '%s'", Integer.valueOf(lineCount), file.getAbsolutePath()));
                }
            }
        } finally {
            // Closing the underlying stream releases the file; the reader holds no other resource.
            fileInputStream.close();
        }
    }

    /**
     * Saves the model to a file (created or overwritten).
     *
     * @see #saveToStream(OutputStream)
     */
    public void saveToFile(File file) throws IOException {
        FileOutputStream fileOutputStream = new FileOutputStream(file);
        try {
            saveToStream(fileOutputStream);
        } finally {
            fileOutputStream.close();
        }
    }

    /**
     * Writes this model as a GZIP-compressed Java serialization stream. The caches are cleared
     * first so that only the raw counts are written. The GZIP trailer is written but
     * {@code outputStream} is left open: closing it is the caller's job.
     *
     * @throws IllegalStateException if {@link #disableUpdate()} has been called, since the raw
     *         counts needed to rebuild the model are gone
     */
    public void saveToStream(OutputStream outputStream) throws IOException {
        if (this.updateDisabled) {
            throw new IllegalStateException("model in disabled update mode cannot be saved");
        }
        invalidateCache();
        GZIPOutputStream gzipOutputStream = new GZIPOutputStream(outputStream);
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(gzipOutputStream);
        objectOutputStream.writeObject(this);
        // Push any data buffered by the object stream before writing the GZIP trailer.
        objectOutputStream.flush();
        gzipOutputStream.finish();
    }

    /**
     * Reads a model written by {@link #saveToStream(OutputStream)}. The stream is not closed.
     * <p>
     * Security note: this uses native Java deserialization, which must only be applied to
     * trusted model files, never to user-supplied data.
     */
    public static TfIdfCategorizer load(InputStream inputStream) throws IOException, ClassNotFoundException {
        return (TfIdfCategorizer) new ObjectInputStream(new GZIPInputStream(inputStream)).readObject();
    }

    /**
     * Loads a model from a resource of the current thread's context class loader.
     *
     * @param str the resource path of the GZIP-compressed model
     * @throws FileNotFoundException if no such resource exists
     */
    public static TfIdfCategorizer load(String str) throws IOException, ClassNotFoundException {
        InputStream resource = Thread.currentThread().getContextClassLoader().getResourceAsStream(str);
        if (resource == null) {
            throw new FileNotFoundException(String.format("model resource '%s' not found in the context classloader", str));
        }
        try {
            return load(resource);
        } finally {
            resource.close();
        }
    }

    /**
     * Command line trainer: {@code <model file> <training folder> [dimension]}. Loads the model
     * file when it exists (the dimension argument is then ignored), otherwise creates a new
     * model, trains it on the folder and saves it back to the model file.
     */
    public static void main(String[] strArr) throws FileNotFoundException, IOException, ClassNotFoundException {
        if (strArr.length < 2 || strArr.length > 3) {
            System.out.println("Train a model:\nFirst argument is the model filename (e.g. my-model.gz)\nSecond argument is the path to a folder with UTF-8 text files\nThird optional argument is the dimension of the model");
            System.exit(0);
        }
        File modelFile = new File(strArr[0]);
        TfIdfCategorizer categorizer;
        if (modelFile.exists()) {
            log.info("Loading model from: " + modelFile.getAbsolutePath());
            FileInputStream modelStream = new FileInputStream(modelFile);
            try {
                categorizer = load(modelStream);
            } finally {
                modelStream.close();
            }
        } else {
            categorizer = strArr.length == 3 ? new TfIdfCategorizer(Integer.parseInt(strArr[2])) : new TfIdfCategorizer();
            log.info("Initializing new model with dimension: " + categorizer.getDimension());
        }
        categorizer.learnFiles(new File(strArr[1]));
        log.info("Saving trained model to: " + modelFile.getAbsolutePath());
        categorizer.saveToFile(modelFile);
    }

    /** Same as {@link #guessCategories(String, int, Double)} with the default ratio. */
    @Override
    public List<String> guessCategories(String str, int i) {
        return guessCategories(str, i, null);
    }

    /**
     * Returns the best matching topics for a text.
     * <p>
     * Topics are taken by decreasing similarity while their similarity is at least {@code d}
     * times the median similarity. When the median is zero, any topic with a positive
     * similarity qualifies (ratio treated as 100) and topics with zero similarity never do.
     *
     * @param str the text to categorize
     * @param i maximum number of topics to return
     * @param d minimal ratio over the median similarity, or null for {@link #ratioOverMedian}
     * @return the selected topic names, best first (possibly empty)
     */
    @Override
    public List<String> guessCategories(String str, int i, Double d) {
        Double threshold = d == null ? this.ratioOverMedian : d;
        Map<String, Float> similarities = getSimilarities(tokenize(str));
        float median = findMedian(similarities).floatValue();
        List<String> categories = new ArrayList<String>();
        for (Map.Entry<String, Float> entry : similarities.entrySet()) {
            if (categories.size() >= i) {
                break;
            }
            float similarity = entry.getValue().floatValue();
            double ratio;
            if (median != 0.0f) {
                ratio = similarity / median;
            } else {
                // Without a usable median only a positive similarity is meaningful; previously
                // every topic (even a zero-similarity one) got ratio 100 and was selected.
                ratio = similarity > 0.0f ? 100.0d : 0.0d;
            }
            // Entries are sorted by decreasing similarity, so no later entry can qualify.
            if (ratio < threshold.doubleValue()) {
                break;
            }
            categories.add(entry.getKey());
        }
        return categories;
    }

    /**
     * Splits a text into terms with {@link #getAnalyzer()}.
     *
     * @param str the text to tokenize
     * @return the terms in text order
     * @throws IllegalStateException wrapping any {@link IOException} raised by the token stream
     */
    public List<String> tokenize(String str) {
        List<String> terms = new ArrayList<String>();
        TokenStream tokenStream = getAnalyzer().tokenStream((String) null, new StringReader(str));
        try {
            Token reusableToken = new Token();
            // next(Token) may return a different instance than the one passed in: always read
            // the term from the returned token.
            for (Token token = tokenStream.next(reusableToken); token != null; token = tokenStream.next(reusableToken)) {
                terms.add(token.termText());
            }
            tokenStream.close();
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
        return terms;
    }

    /**
     * Returns a copy of {@code map} whose iteration order is by decreasing value; entries with
     * equal values keep their original relative order (the sort is stable).
     */
    public static Map<String, Float> sortByDecreasingValue(Map<String, Float> map) {
        List<Map.Entry<String, Float>> entries = new ArrayList<Map.Entry<String, Float>>(map.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Float>>() {
            @Override
            public int compare(Map.Entry<String, Float> entry, Map.Entry<String, Float> entry2) {
                return entry2.getValue().compareTo(entry.getValue());
            }
        });
        Map<String, Float> sorted = new LinkedHashMap<String, Float>();
        for (Map.Entry<String, Float> entry : entries) {
            sorted.put(entry.getKey(), entry.getValue());
        }
        return sorted;
    }

    /**
     * Returns the value at position {@code size / 2} in the iteration order of {@code map}, or 0
     * for an empty map. This is the median only if the map iterates in value order (as the result
     * of {@link #sortByDecreasingValue(Map)} does); for an even size it picks the lower of the two
     * middle values of a decreasing order.
     */
    public static Float findMedian(Map<String, Float> map) {
        int remaining = map.size() / 2;
        Float median = Float.valueOf(0.0f);
        for (Float value : map.values()) {
            median = value;
            if (remaining-- <= 0) {
                break;
            }
        }
        return median;
    }
}
