package org.deeplearning4j.bagofwords.vectorizer;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Text vectorizer that turns a document into a single row vector of TF-IDF scores, one column per
 * word in the vocabulary cache.
 *
 * <p>The fields used here ({@code tokenizerFactory}, {@code iterator}, {@code minWordFrequency},
 * {@code labelsSource}, {@code isParallel}, {@code vocabCache}, {@code stopWords}) are declared on
 * {@code BaseTextVectorizer}, which is not shown here. Instances are normally created through
 * {@link Builder}.
 */
public class TfidfVectorizer extends BaseTextVectorizer {
    private static final Logger log = LoggerFactory.getLogger(TfidfVectorizer.class);

    /** Fluent builder that copies its settings onto a new {@link TfidfVectorizer}. */
    public static class Builder {
        protected TokenizerFactory tokenizerFactory;
        protected LabelAwareIterator iterator;
        protected int minWordFrequency;
        protected VocabCache<VocabWord> vocabCache;
        protected LabelsSource labelsSource = new LabelsSource();
        protected Collection<String> stopWords = new ArrayList<>();
        protected boolean isParallel = true;

        /**
         * Sets whether tokenization may run in parallel. Defaults to {@code true}.
         *
         * @param parallel the flag copied onto the built vectorizer's {@code isParallel}
         * @return this builder
         */
        public Builder allowParallelTokenization(boolean parallel) {
            this.isParallel = parallel;
            return this;
        }

        /**
         * Sets the tokenizer used by {@link TfidfVectorizer#transform(String)}.
         *
         * @param tokenizerFactory factory used to tokenize text; must not be null
         * @return this builder
         * @throws NullPointerException if {@code tokenizerFactory} is null
         */
        public Builder setTokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
            if (tokenizerFactory == null) {
                throw new NullPointerException("tokenizerFactory is marked @NonNull but is null");
            }
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets the document source directly.
         *
         * @param labelAwareIterator the iterator to use; must not be null
         * @return this builder
         * @throws NullPointerException if {@code labelAwareIterator} is null
         */
        public Builder setIterator(@NonNull LabelAwareIterator labelAwareIterator) {
            if (labelAwareIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            this.iterator = labelAwareIterator;
            return this;
        }

        /**
         * Sets the document source, wrapping it in a {@link DocumentIteratorConverter} bound to this
         * builder's {@link LabelsSource}.
         *
         * @param documentIterator the iterator to wrap; must not be null
         * @return this builder
         * @throws NullPointerException if {@code documentIterator} is null
         */
        public Builder setIterator(@NonNull DocumentIterator documentIterator) {
            if (documentIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            this.iterator = new DocumentIteratorConverter(documentIterator, this.labelsSource);
            return this;
        }

        /**
         * Sets the document source, wrapping it in a {@link SentenceIteratorConverter} bound to this
         * builder's {@link LabelsSource}.
         *
         * @param sentenceIterator the iterator to wrap; must not be null
         * @return this builder
         * @throws NullPointerException if {@code sentenceIterator} is null
         */
        public Builder setIterator(@NonNull SentenceIterator sentenceIterator) {
            if (sentenceIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            this.iterator = new SentenceIteratorConverter(sentenceIterator, this.labelsSource);
            return this;
        }

        /**
         * Supplies an existing vocabulary cache instead of letting {@link #build()} create one.
         *
         * @param vocabCache the cache to use; must not be null
         * @return this builder
         * @throws NullPointerException if {@code vocabCache} is null
         */
        public Builder setVocab(@NonNull VocabCache<VocabWord> vocabCache) {
            if (vocabCache == null) {
                throw new NullPointerException("vocab is marked @NonNull but is null");
            }
            this.vocabCache = vocabCache;
            return this;
        }

        /**
         * Sets the minimum word frequency copied onto the built vectorizer.
         *
         * @param minWordFrequency value for {@code minWordFrequency}
         * @return this builder
         */
        public Builder setMinWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        /**
         * Sets the stop words copied onto the built vectorizer. The collection is stored by
         * reference, not copied.
         *
         * @param stopWords stop words to use
         * @return this builder
         */
        public Builder setStopWords(Collection<String> stopWords) {
            this.stopWords = stopWords;
            return this;
        }

        /**
         * Creates a vectorizer from the current settings.
         *
         * <p>If no vocabulary cache was set, a new {@link AbstractCache} is created and also stored on
         * this builder. Later {@code build()} calls therefore share that same cache.
         *
         * @return a new, configured {@link TfidfVectorizer}
         */
        public TfidfVectorizer build() {
            TfidfVectorizer vectorizer = new TfidfVectorizer();
            vectorizer.tokenizerFactory = this.tokenizerFactory;
            vectorizer.iterator = this.iterator;
            vectorizer.minWordFrequency = this.minWordFrequency;
            vectorizer.labelsSource = this.labelsSource;
            vectorizer.isParallel = this.isParallel;
            if (this.vocabCache == null) {
                this.vocabCache = new AbstractCache.Builder<VocabWord>().build();
            }
            vectorizer.vocabCache = this.vocabCache;
            vectorizer.stopWords = this.stopWords;
            return vectorizer;
        }
    }

    /**
     * Reads the whole stream as UTF-8 text and vectorizes it with the given label.
     *
     * <p>Line breaks are preserved, as in {@link #vectorize(File, String)}. This keeps the last word of
     * one line from being glued to the first word of the next. The stream is owned by the caller and
     * is not closed here.
     *
     * @param inputStream source of the document text
     * @param label label looked up in {@code labelsSource}
     * @return the TF-IDF feature row paired with the label's outcome vector
     * @throws UncheckedIOException if reading the stream fails
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public DataSet vectorize(InputStream inputStream, String label) {
        StringBuilder text = new StringBuilder();
        try {
            BufferedReader reader =
                    new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
            String line;
            while ((line = reader.readLine()) != null) {
                if (text.length() > 0) {
                    text.append('\n');
                }
                text.append(line);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return vectorize(text.toString(), label);
    }

    /**
     * Vectorizes raw text and pairs it with the outcome vector for {@code label}.
     *
     * @param text document text, tokenized via {@link #transform(String)}
     * @param label label whose index in {@code labelsSource} selects the outcome
     * @return the TF-IDF feature row paired with the label's outcome vector
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public DataSet vectorize(String text, String label) {
        INDArray features = transform(text);
        INDArray outcome =
                FeatureUtil.toOutcomeVector(this.labelsSource.indexOf(label), this.labelsSource.size());
        return new DataSet(features, outcome);
    }

    /**
     * Reads the whole file as UTF-8 text and vectorizes it with the given label.
     *
     * @param file file containing the document text
     * @param label label looked up in {@code labelsSource}
     * @return the TF-IDF feature row paired with the label's outcome vector
     * @throws UncheckedIOException if the file cannot be read
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public DataSet vectorize(File file, String label) {
        String text;
        try {
            // Explicit charset: the single-argument overload uses the platform default encoding.
            text = FileUtils.readFileToString(file, StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return vectorize(text, label);
    }

    /**
     * Tokenizes {@code text} with the configured tokenizer factory and returns its TF-IDF row.
     *
     * @param text document text
     * @return a {@code 1 x numWords} row of TF-IDF scores
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public INDArray transform(String text) {
        return transform(this.tokenizerFactory.create(text).getTokens());
    }

    /**
     * Builds a {@code 1 x numWords} row vector holding the TF-IDF score of every token that exists in
     * the vocabulary cache. Tokens not in the cache are ignored, and all other entries remain at the
     * value from {@code Nd4j.create}.
     *
     * @param tokens the document's tokens; each occurrence counts toward that word's frequency
     * @return the TF-IDF row for the document
     */
    @Override // org.deeplearning4j.bagofwords.vectorizer.TextVectorizer
    public INDArray transform(List<String> tokens) {
        INDArray row = Nd4j.create(1, this.vocabCache.numWords());

        Map<String, Long> counts = new HashMap<>();
        for (String token : tokens) {
            counts.merge(token, 1L, Long::sum);
        }

        long documentLength = tokens.size();
        // Score each distinct word once; repeated occurrences would only rewrite the same value.
        for (Map.Entry<String, Long> entry : counts.entrySet()) {
            String word = entry.getKey();
            int column = this.vocabCache.indexOf(word);
            if (column >= 0) {
                row.putScalar(column, tfidfWord(word, entry.getValue(), documentLength));
            }
        }
        return row;
    }

    /**
     * Combines this word's term frequency in the document with its inverse document frequency from
     * the vocabulary cache via {@code MathUtils.tfidf}.
     *
     * @param word the word being scored
     * @param wordCount occurrences of {@code word} in the document
     * @param documentLength total number of tokens in the document
     * @return the TF-IDF score
     */
    public double tfidfWord(String word, long wordCount, long documentLength) {
        return MathUtils.tfidf(tfForWord(wordCount, documentLength), idfForWord(word));
    }

    /**
     * Term frequency as the fraction of the document's tokens that are this word.
     *
     * <p>Uses floating-point division. Plain {@code long / long} truncates, which made the frequency
     * 0 for every word that is not the whole document.
     *
     * @return {@code wordCount / documentLength}, or {@code 0.0} for an empty document
     */
    private double tfForWord(long wordCount, long documentLength) {
        if (documentLength == 0) {
            return 0.0;
        }
        return (double) wordCount / documentLength;
    }

    /**
     * Inverse document frequency from the cache's total document count and the number of documents
     * {@code word} appeared in, delegated to {@code MathUtils.idf}.
     */
    private double idfForWord(String word) {
        return MathUtils.idf(this.vocabCache.totalNumberOfDocs(), this.vocabCache.docAppearedIn(word));
    }

    /**
     * Not implemented by this vectorizer.
     *
     * @return always {@code null}
     */
    public DataSet vectorize() {
        return null;
    }
}
