/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.classification.classifiers;

import com.hankcs.hanlp.classification.classifiers.IClassifier;
import com.hankcs.hanlp.classification.corpus.Document;
import com.hankcs.hanlp.classification.corpus.MemoryDataSet;
import com.hankcs.hanlp.classification.models.AbstractModel;
import com.hankcs.hanlp.classification.utilities.CollectionUtility;
import com.hankcs.hanlp.classification.utilities.MathUtility;
import com.hankcs.hanlp.classification.utilities.Predefine;
import java.io.IOException;
import java.util.Map;
import java.util.TreeMap;

/**
 * Skeleton implementation of {@link IClassifier}.
 *
 * <p>Supplies the convenience overloads ({@code classify}, the folder/map based {@code train}
 * overloads, {@code predict(Document)} and {@code label(Document)}) on top of the primitives a
 * concrete classifier must provide: {@code getModel()}, {@code categorize(Document)},
 * {@code predict(String)} and a data-set based {@code train} overload.
 */
public abstract class AbstractClassifier
implements IClassifier {
    /**
     * Flag set by {@link #enableProbability(boolean)}; defaults to {@code true}.
     *
     * <p>It is not read anywhere in this class. Presumably concrete subclasses consult it when
     * scoring (e.g. to decide whether scores are turned into probabilities) — confirm against
     * the subclasses.
     */
    boolean configProbabilityEnabled = true;

    /**
     * Turns probability output on or off.
     *
     * @param enable new value for {@link #configProbabilityEnabled}
     * @return this classifier, to allow call chaining
     */
    @Override
    public IClassifier enableProbability(boolean enable) {
        this.configProbabilityEnabled = enable;
        return this;
    }

    /**
     * Classifies raw text and returns the winning category.
     *
     * @param text the text to classify
     * @return the category chosen by {@code CollectionUtility.max} over the score map produced
     *         by {@link #predict(String)} (presumably the highest-scoring key)
     * @throws IllegalArgumentException declared by the contract; presumably raised by
     *         {@code predict(String)} for bad input
     * @throws IllegalStateException declared by the contract; presumably raised by
     *         {@code predict(String)} when no model is trained
     */
    @Override
    public String classify(String text) throws IllegalArgumentException, IllegalStateException {
        Map<String, Double> scoreMap = this.predict(text);
        return CollectionUtility.max(scoreMap);
    }

    /**
     * Classifies an already-built document and returns the winning category.
     *
     * @param document the document to classify
     * @return the category chosen by {@code CollectionUtility.max} over the score map produced
     *         by {@link #predict(Document)} (presumably the highest-scoring key)
     * @throws IllegalArgumentException if {@code document} is {@code null}
     * @throws IllegalStateException if no model has been trained
     */
    @Override
    public String classify(Document document) throws IllegalArgumentException, IllegalStateException {
        Map<String, Double> scoreMap = this.predict(document);
        return CollectionUtility.max(scoreMap);
    }

    /**
     * Loads a training corpus from a folder into a {@link MemoryDataSet} and trains on it.
     *
     * @param folderPath  root folder of the corpus, passed to {@code MemoryDataSet.load}
     * @param charsetName charset used to read the corpus files
     * @throws IOException if loading the corpus fails
     */
    @Override
    public void train(String folderPath, String charsetName) throws IOException {
        MemoryDataSet dataSet = new MemoryDataSet();
        dataSet.load(folderPath, charsetName);
        this.train(dataSet);
    }

    /**
     * Builds an in-memory data set from a category-to-documents map and trains on it.
     * Progress is logged per category through {@code Predefine.logger}.
     *
     * @param trainingDataSet map from category name to the raw text documents of that category
     * @throws IllegalArgumentException if {@code trainingDataSet} is {@code null}
     */
    @Override
    public void train(Map<String, String[]> trainingDataSet) throws IllegalArgumentException {
        if (trainingDataSet == null) {
            // "argument trainingDataSet == null"
            throw new IllegalArgumentException("\u53c2\u6570 trainingDataSet == null");
        }
        MemoryDataSet dataSet = new MemoryDataSet();
        // "Building training data set..."
        Predefine.logger.start("\u6b63\u5728\u6784\u9020\u8bad\u7ec3\u6570\u636e\u96c6...", new Object[0]);
        int total = trainingDataSet.size();
        int processed = 0;
        for (Map.Entry<String, String[]> entry : trainingDataSet.entrySet()) {
            String category = entry.getKey();
            Predefine.logger.out("[%s]...", category);
            for (String doc : entry.getValue()) {
                dataSet.add(category, doc);
            }
            // Percentage of categories processed so far.
            Predefine.logger.out("%.2f%%...", MathUtility.percentage(++processed, total));
        }
        // " loading complete"
        Predefine.logger.finish(" \u52a0\u8f7d\u5b8c\u6bd5\n", new Object[0]);
        this.train(dataSet);
    }

    /**
     * Trains on a UTF-8 encoded corpus folder.
     *
     * @param folderPath root folder of the corpus
     * @throws IOException if loading the corpus fails
     */
    @Override
    public void train(String folderPath) throws IOException {
        this.train(folderPath, "UTF-8");
    }

    /**
     * Scores a document against every category of the trained model.
     *
     * <p>Entry {@code i} of {@code categorize(document)} is paired with
     * {@code model.catalog[i]}. This assumes the score array is no longer than the catalog;
     * a longer array would throw {@link ArrayIndexOutOfBoundsException}.
     *
     * @param document the document to score
     * @return a {@link TreeMap} (sorted by category name) from category to score
     * @throws IllegalStateException if no model has been trained
     * @throws IllegalArgumentException if {@code document} is {@code null}
     */
    @Override
    public Map<String, Double> predict(Document document) {
        AbstractModel model = this.checkPredictionInputs(document);
        double[] scores = this.categorize(document);
        Map<String, Double> scoreMap = new TreeMap<String, Double>();
        for (int i = 0; i < scores.length; ++i) {
            scoreMap.put(model.catalog[i], scores[i]);
        }
        return scoreMap;
    }

    /**
     * Returns the index of the highest-scoring category for a document.
     *
     * <p>Ties resolve to the lowest index, because only a strictly greater score replaces the
     * current best. NaN scores never win.
     *
     * @param document the document to label
     * @return index into the score array from {@code categorize(document)}, or {@code -1} if no
     *         score exceeds negative infinity (empty array, or all NaN / -Infinity)
     * @throws IllegalStateException if no model has been trained
     * @throws IllegalArgumentException if {@code document} is {@code null}
     */
    @Override
    public int label(Document document) throws IllegalArgumentException, IllegalStateException {
        this.checkPredictionInputs(document);
        double[] scores = this.categorize(document);
        double max = Double.NEGATIVE_INFINITY;
        int best = -1;
        for (int i = 0; i < scores.length; ++i) {
            if (scores[i] > max) {
                max = scores[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * Checks the preconditions shared by {@link #predict(Document)} and {@link #label(Document)}.
     * The model is checked before the argument.
     *
     * @param document the document about to be scored
     * @return the trained model (never {@code null})
     * @throws IllegalStateException if {@code getModel()} returns {@code null}
     * @throws IllegalArgumentException if {@code document} is {@code null}
     */
    private AbstractModel checkPredictionInputs(Document document) {
        AbstractModel model = this.getModel();
        if (model == null) {
            // "Model not trained! Cannot perform prediction!"
            throw new IllegalStateException("\u672a\u8bad\u7ec3\u6a21\u578b\uff01\u65e0\u6cd5\u6267\u884c\u9884\u6d4b\uff01");
        }
        if (document == null) {
            // "argument document == null"
            throw new IllegalArgumentException("\u53c2\u6570 document == null");
        }
        return model;
    }
}

