package info.ephyra.questionanalysis.atype.minorthird.hierarchical;

import edu.cmu.lti.javelin.util.FileUtil;
import edu.cmu.lti.javelin.util.Language;
import edu.cmu.lti.javelin.util.MLToolkit;
import edu.cmu.lti.util.Pair;
import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.CascadingBinaryLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.MostFrequentFirstLearner;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.OneVsAllLearner;
import edu.cmu.minorthird.classify.algorithms.knn.KnnLearner;
import edu.cmu.minorthird.classify.algorithms.linear.BalancedWinnow;
import edu.cmu.minorthird.classify.algorithms.linear.KWayMixtureLearner;
import edu.cmu.minorthird.classify.algorithms.linear.MarginPerceptron;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import edu.cmu.minorthird.classify.algorithms.linear.NegativeBinomialLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.algorithms.random.RandomElement;
import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.CrossValidatedDataset;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.classify.experiments.Tester;
import edu.cmu.minorthird.util.IOUtil;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import info.ephyra.questionanalysis.TermExpander;
import info.ephyra.questionanalysis.atype.extractor.FeatureExtractor;
import info.ephyra.questionanalysis.atype.extractor.FeatureExtractorFactory;
import info.ephyra.util.Properties;
import java.io.File;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Arrays;
import java.util.Date;
import java.util.Formatter;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import libsvm.svm_parameter;
import org.apache.log4j.Logger;

/* loaded from: input_file:info/ephyra/questionanalysis/atype/minorthird/hierarchical/HierarchicalClassifierTrainer.class */
public class HierarchicalClassifierTrainer {
    private static final Logger log = Logger.getLogger(HierarchicalClassifierTrainer.class);

    /** Feature extractor for the question language; created in {@link #initialize()}. */
    protected FeatureExtractor extractor;
    /** Training data file ({@code trainingFile} property). */
    protected String trainingFile;
    /** Test data file ({@code testingFile} property); only used when {@link #crossValidationFolds} is negative. */
    protected String testingFile;
    /** Number of cross-validation folds; a negative value means "evaluate on {@link #testingFile}". */
    protected int crossValidationFolds;
    /** Learner names (one per class level); reduced to the first entry when {@link #useClassLevels} is false. */
    protected String[] learnerNames;
    /** Whether a hierarchical (multi-level) classifier is built ({@code useClassLevels} property). */
    protected boolean useClassLevels;
    /** Accepted class labels, normalized via {@code HierarchicalClassifier.getHierarchicalClassName}. */
    protected HashSet<String> classLabels;
    /** Labels observed while loading the training set; test examples with other labels are discarded. */
    protected HashSet<String> trainingLabels;
    /** Feature types to keep, matched against the first part of each feature name. */
    protected HashSet<String> featureTypes;
    /** True while {@code makeDataset} is loading the training set. */
    protected boolean loadTraining;
    /** Prefix prepended to file names written by {@link #saveClassifier()}. */
    protected String classifierDir;
    protected Dataset trainingSet;
    protected Dataset testingSet;
    protected Classifier classifier;
    protected Pair<Language, Language> languagePair;
    protected Properties properties;
    protected CrossValidatedDataset cvDataset;
    protected Evaluation evaluation;
    /** Elapsed milliseconds of the last {@link #runExperiment()} or {@link #trainClassifier()} call. */
    protected long runTime;

    /**
     * Creates a trainer for a language pair. {@link #initialize()} must be called before
     * training or running experiments.
     *
     * @param pair question language (first) and corpus language (second)
     */
    public HierarchicalClassifierTrainer(Pair<Language, Language> pair) {
        this.languagePair = pair;
    }

    /**
     * Overlays the given properties on the current configuration and re-initializes the trainer.
     * If no configuration has been loaded yet, the defaults for the language pair are loaded
     * first, so this may be called before {@link #initialize()}. Errors are logged, not thrown.
     *
     * @param properties entries that override the current configuration
     */
    public void setProperties(Properties properties) {
        try {
            if (this.properties == null) {
                // The overlay needs a base configuration to extend; the field is null before initialize().
                loadDefaultProperties();
            }
            for (String str : properties.keySet()) {
                this.properties.put(str, properties.get(str));
            }
            initialize();
        } catch (Exception e) {
            log.error("Error re-initializing: ", e);
        }
    }

    /**
     * Reads the configuration and loads the training set. The test set is also loaded when
     * {@code crossValidationFolds} is negative. Default properties are loaded from the
     * {@code <questionLang>_<corpusLang>} section if none have been set.
     *
     * @throws Exception if the language pair is unset, the configuration section or a required
     *     property is missing, or a data file cannot be loaded
     */
    public void initialize() throws Exception {
        requireLanguagePair();
        if (this.properties == null) {
            loadDefaultProperties();
        }
        if (this.extractor == null) {
            this.extractor = FeatureExtractorFactory.getInstance(this.languagePair.getFirst());
        }
        this.trainingFile = this.properties.getProperty("trainingFile");
        this.testingFile = this.properties.getProperty("testingFile");
        this.crossValidationFolds = Integer.parseInt(requireProperty("crossValidationFolds").trim());
        this.learnerNames = splitAndTrim(requireProperty("learners"));
        this.useClassLevels = Boolean.parseBoolean(this.properties.getProperty("useClassLevels"));
        if (!this.useClassLevels && this.learnerNames.length > 1) {
            // A flat (non-hierarchical) classifier only uses the first learner.
            this.learnerNames = Arrays.copyOf(this.learnerNames, 1);
        }
        this.classLabels = new HashSet<>();
        for (String label : splitAndTrim(requireProperty("classLabels"))) {
            this.classLabels.add(HierarchicalClassifier.getHierarchicalClassName(
                    label, this.learnerNames.length, this.useClassLevels));
        }
        this.featureTypes = new HashSet<>(Arrays.asList(splitAndTrim(requireProperty("featureTypes"))));
        this.classifierDir = this.properties.getProperty("classifierDir");
        // Forget labels from a previous initialization so the next dataset is loaded as training data.
        this.trainingLabels = null;
        this.trainingSet = makeDataset(this.trainingFile);
        if (this.crossValidationFolds < 0) {
            this.testingSet = makeDataset(this.testingFile);
        }
    }

    /** Throws if the language pair has not been set. */
    private void requireLanguagePair() throws Exception {
        if (this.languagePair == null) {
            throw new Exception("Language pair must be set before calling initialize");
        }
    }

    /** Loads the default properties section for the language pair into {@link #properties}. */
    private void loadDefaultProperties() throws Exception {
        requireLanguagePair();
        String section = this.languagePair.getFirst() + "_" + this.languagePair.getSecond();
        Properties loaded = Properties.loadFromClassName(getClass().getName()).mapProperties().get(section);
        if (loaded == null) {
            throw new Exception("No properties found for language pair: " + section);
        }
        this.properties = loaded;
    }

    /** Returns the value of a required property, failing with a descriptive message if absent. */
    private String requireProperty(String key) throws Exception {
        String value = this.properties.getProperty(key);
        if (value == null) {
            throw new Exception("Missing required property: " + key);
        }
        return value;
    }

    /** Splits a comma-separated list and trims each entry. */
    private static String[] splitAndTrim(String value) {
        String[] parts = value.split(",");
        for (int i = 0; i < parts.length; i++) {
            parts[i] = parts[i].trim();
        }
        return parts;
    }

    /**
     * Loads a data file and keeps only examples whose label is in {@link #classLabels}. For each
     * kept example, only features whose type is in {@link #featureTypes} are retained. The first
     * call after (re)initialization is treated as the training set and records its labels. Later
     * calls also drop examples whose label was not seen in training.
     *
     * @param str data file path passed to the feature extractor
     * @return the filtered dataset
     */
    private Dataset makeDataset(String str) {
        if (this.trainingLabels == null) {
            this.loadTraining = true;
            this.trainingLabels = new HashSet<>();
        }
        BasicDataset basicDataset = new BasicDataset();
        this.extractor.setUseClassLevels(this.useClassLevels);
        this.extractor.setClassLevels(this.learnerNames.length);
        for (Example source : this.extractor.loadFile(str)) {
            String bestClassName = source.getLabel().bestClassName();
            if (!this.classLabels.contains(bestClassName)) {
                MLToolkit.println("Discarding example for Class: " + bestClassName);
                continue;
            }
            Example example = new Example(selectFeatures(source), source.getLabel());
            MLToolkit.println(example);
            if (this.loadTraining) {
                this.trainingLabels.add(bestClassName);
                basicDataset.add(example);
            } else if (this.trainingLabels.contains(bestClassName)) {
                basicDataset.add(example);
            } else {
                MLToolkit.println("Label of test example not found in training set (discarding): " + bestClassName);
            }
        }
        this.loadTraining = false;
        MLToolkit.println("Loaded " + basicDataset.size() + " examples for experiment from " + str);
        return basicDataset;
    }

    /** Copies an example's binary and numeric features whose type is in {@link #featureTypes}. */
    private MutableInstance selectFeatures(Example source) {
        MutableInstance instance = new MutableInstance(source.getSource(), source.getSubpopulationId());
        Feature.Looper binaryFeatures = source.binaryFeatureIterator();
        while (binaryFeatures.hasNext()) {
            Feature feature = binaryFeatures.nextFeature();
            if (this.featureTypes.contains(feature.getPart(0))) {
                instance.addBinary(feature);
            }
        }
        Feature.Looper numericFeatures = source.numericFeatureIterator();
        while (numericFeatures.hasNext()) {
            Feature feature = numericFeatures.nextFeature();
            if (this.featureTypes.contains(feature.getPart(0))) {
                instance.addNumeric(feature, source.getWeight(feature));
            }
        }
        return instance;
    }

    /**
     * Creates a hierarchical learner with one underlying learner per name.
     *
     * @param strArr learner names understood by {@link #createLearnerByName(String)}
     * @return the hierarchical learner (entries for unrecognized names are null)
     */
    public HierarchicalClassifierLearner createHierarchicalClassifierLearner(String[] strArr) {
        ClassifierLearner[] learners = new ClassifierLearner[strArr.length];
        for (int i = 0; i < learners.length; i++) {
            learners[i] = createLearnerByName(strArr[i]);
        }
        return new HierarchicalClassifierLearner(learners);
    }

    /**
     * Creates a MinorThird learner from its configuration name (case-insensitive).
     * Names ending in {@code _OVA}, {@code _CB} and {@code _MFF} wrap a binary learner in a
     * one-vs-all, cascading-binary or most-frequent-first multi-class learner, respectively.
     *
     * @param str learner name, e.g. {@code KNN}, {@code MAX_ENT}, {@code SVM_OVA}
     * @return the learner, or {@code null} (after logging an error) if the name is unknown
     */
    public ClassifierLearner createLearnerByName(String str) {
        switch (str.toUpperCase(Locale.ROOT)) {
            case "KNN":
                return new KnnLearner();
            case "KWAY_MIX":
                return new KWayMixtureLearner();
            case "MAX_ENT":
                return new MaxEntLearner();
            case "BWINNOW_OVA":
                return new OneVsAllLearner(new BalancedWinnow());
            case "MPERCEPTRON_OVA":
                return new OneVsAllLearner(new MarginPerceptron());
            case "NBAYES_OVA":
                return new OneVsAllLearner(new NaiveBayes());
            case "VPERCEPTRON_OVA":
                return new OneVsAllLearner(new VotedPerceptron());
            case "ADABOOST_OVA":
                return new OneVsAllLearner(new AdaBoost());
            case "ADABOOST_CB":
                return new CascadingBinaryLearner(new AdaBoost());
            case "ADABOOST_MFF":
                return new MostFrequentFirstLearner(new AdaBoost());
            case "ADABOOSTL_OVA":
                return new OneVsAllLearner(new AdaBoost.L());
            case "ADABOOSTL_CB":
                return new CascadingBinaryLearner(new AdaBoost.L());
            case "ADABOOSTL_MFF":
                return new MostFrequentFirstLearner(new AdaBoost.L());
            case "DTREE_OVA":
                return new OneVsAllLearner(new DecisionTreeLearner());
            case "DTREE_CB":
                return new CascadingBinaryLearner(new DecisionTreeLearner());
            case "DTREE_MFF":
                return new MostFrequentFirstLearner(new DecisionTreeLearner());
            case "NEGBI_OVA":
                return new OneVsAllLearner(new NegativeBinomialLearner());
            case "NEGBI_CB":
                return new CascadingBinaryLearner(new NegativeBinomialLearner());
            case "NEGBI_MFF":
                return new MostFrequentFirstLearner(new NegativeBinomialLearner());
            case "SVM_OVA":
                return new OneVsAllLearner(new SVMLearner());
            case "SVM_OVA_CONF1":
                return new OneVsAllLearner(new SVMLearner(createConf1SvmParameters()));
            case "SVM_CB":
                return new CascadingBinaryLearner(new SVMLearner());
            case "SVM_MFF":
                return new MostFrequentFirstLearner(new SVMLearner());
            default:
                log.error("Unrecognized learner name: " + str);
                return null;
        }
    }

    /** Builds the libsvm parameters of the {@code SVM_OVA_CONF1} learner (C-SVC, polynomial kernel, degree 2). */
    private static svm_parameter createConf1SvmParameters() {
        svm_parameter param = new svm_parameter();
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.POLY;
        param.degree = 2.0d;
        param.gamma = 1.0d;
        // NOTE(review): coef0 is taken from an unrelated constant, most likely a decompiler
        // artifact that replaced an equal-valued literal; confirm the intended value.
        param.coef0 = TermExpander.MIN_EXPANSION_WEIGHT;
        param.nu = 0.5d;
        param.cache_size = 40.0d;
        param.C = 1.0d;
        param.eps = 0.001d;
        param.p = 0.1d;
        param.shrinking = 1;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];
        return param;
    }

    /**
     * Evaluates the configured learners. Uses the separate test set when
     * {@code crossValidationFolds} is negative; otherwise runs randomized cross-validation.
     * Records the elapsed time in {@link #runTime}.
     *
     * @return the resulting evaluation (also stored in {@link #evaluation})
     */
    public Evaluation runExperiment() {
        long start = System.nanoTime();
        HierarchicalClassifierLearner learner = createHierarchicalClassifierLearner(this.learnerNames);
        if (this.crossValidationFolds < 0) {
            this.evaluation = Tester.evaluate(learner, this.trainingSet, this.testingSet);
        } else {
            this.cvDataset = new CrossValidatedDataset(learner, this.trainingSet,
                    new CrossValSplitter(new RandomElement(System.currentTimeMillis()), this.crossValidationFolds), true);
            this.evaluation = this.cvDataset.getEvaluation();
        }
        this.runTime = elapsedMillis(start);
        return this.evaluation;
    }

    /** Trains a classifier on the full training set and records the elapsed time in {@link #runTime}. */
    public void trainClassifier() {
        long start = System.nanoTime();
        this.classifier = new DatasetClassifierTeacher(this.trainingSet)
                .train(createHierarchicalClassifierLearner(this.learnerNames));
        this.runTime = elapsedMillis(start);
    }

    /** Returns milliseconds elapsed since {@code startNanos} (a {@link System#nanoTime()} value). */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1000000L;
    }

    /**
     * Serializes the current classifier to a file; failures are logged, not thrown.
     *
     * @param str destination file path
     */
    public void saveClassifier(String str) {
        try {
            IOUtil.saveSerialized(this.classifier, new File(str));
        } catch (Exception e) {
            log.error("Error saving classifier to " + str, e);
        }
    }

    /**
     * Saves the classifier under a generated name of the form
     * {@code <classifierDir><epochSeconds>-<learner>...[-HC]-<trainingFileName>}.
     */
    public void saveClassifier() {
        StringBuilder name = new StringBuilder();
        if (this.classifierDir != null) {
            name.append(this.classifierDir);
        }
        name.append(System.currentTimeMillis() / 1000);
        for (String learnerName : this.learnerNames) {
            name.append('-').append(learnerName);
        }
        if (this.useClassLevels) {
            name.append("-HC");
        }
        name.append('-').append(new File(this.trainingFile).getName());
        saveClassifier(name.toString());
    }

    /**
     * Loads a serialized classifier; failures are logged and leave the current classifier unchanged.
     * NOTE(review): this uses Java deserialization — only load files from trusted sources.
     *
     * @param str source file path
     */
    public void loadClassifier(String str) {
        try {
            this.classifier = (Classifier) IOUtil.loadSerialized(new File(str));
        } catch (Exception e) {
            log.error("Error loading classifier from " + str, e);
        }
    }

    /** Returns the trained or loaded classifier, or {@code null} if there is none. */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /**
     * Builds a plain-text report of the configuration and the results of the last experiment,
     * including per-class accuracy and a confusion matrix. Numbers are formatted with US conventions.
     *
     * @return the report text
     * @throws IllegalStateException if {@link #runExperiment()} has not been called
     */
    public String createReport() {
        if (this.evaluation == null) {
            throw new IllegalStateException("runExperiment() must be called before createReport()");
        }
        DecimalFormat percent = new DecimalFormat("#0.00", new DecimalFormatSymbols(Locale.US));
        StringBuilder report = new StringBuilder();
        report.append("Question Answer Type Classification Report\n");
        report.append(new Date()).append("\n");
        report.append("\n");
        report.append("Training Data File: ").append(this.trainingFile).append("\n");
        if (this.crossValidationFolds < 0) {
            report.append("Testing Data File: ").append(this.testingFile).append("\n");
        } else {
            report.append("Testing using ").append(this.crossValidationFolds).append("-fold cross validation\n");
        }
        report.append("\n");
        report.append("Valid Class Labels:");
        for (String label : this.classLabels) {
            report.append(' ').append(label);
        }
        report.append("\n\n");
        if (this.useClassLevels) {
            report.append("Using Hierarchical Classifier Learners:\n");
            for (String learnerName : this.learnerNames) {
                report.append('\t').append(learnerName).append("\n");
            }
        } else {
            report.append("Using Simple Classifier Learner: ").append(this.learnerNames[0]).append("\n");
        }
        report.append("\n");
        report.append("Feature Selection:\n");
        for (String featureType : this.featureTypes) {
            report.append('\t').append(featureType).append("\n");
        }
        report.append("\n");
        report.append("Experiment Results:\n");
        report.append("\n");
        report.append("\tAccuracy: ").append(percent.format((1.0d - this.evaluation.errorRate()) * 100.0d))
                .append("% [").append(this.evaluation.numExamples()).append(" example(s)]\n");
        report.append("\n");
        report.append("\tAccuracy by Class:\n");
        report.append("\n");
        String[] classes = this.evaluation.getClasses();
        double[] examplesByClass = this.evaluation.numberOfExamplesByClass();
        double[] errorRateByClass = this.evaluation.errorRateByClass();
        double accuracySum = 0.0d;
        for (int i = 0; i < classes.length; i++) {
            double accuracy = (1.0d - errorRateByClass[i]) * 100.0d;
            report.append("\t\t").append(classes[i]).append(' ').append(percent.format(accuracy))
                    .append("% [").append((int) examplesByClass[i]).append(" example(s)]\n");
            accuracySum += accuracy;
        }
        report.append("\n");
        report.append("\tAverage Class Accuracy: ").append(percent.format(accuracySum / classes.length)).append("%\n");
        report.append("\n");
        report.append("Run Time: ").append(this.runTime).append(" ms\n");
        report.append("\n");
        report.append("Confusion Matrix:\n");
        report.append("\n");
        report.append(prettyPrintCM(this.evaluation.confusionMatrix(), classes));
        report.append("\n");
        return report.toString();
    }

    /**
     * Renders a confusion matrix as left-aligned columns. Class names are shortened with a regex
     * that keeps roughly the first three characters and the last character of each word. The
     * column width fits the longest shortened label or cell value.
     *
     * @param matrix confusion matrix whose {@code values[row][col]} are indexed like {@code strArr}
     * @param strArr class names labeling rows and columns
     * @return the rendered matrix
     */
    private String prettyPrintCM(Evaluation.Matrix matrix, String[] strArr) {
        double[][] values = matrix.values;
        int n = strArr.length;
        String[] shortNames = new String[n];
        int width = 0;
        for (int i = 0; i < n; i++) {
            shortNames[i] = strArr[i].replaceAll("\\B(.{1,2}).*?(.)\\b", "$1$2");
            width = Math.max(width, shortNames[i].length());
        }
        // Cell values can be wider than the labels; include them so columns never run together.
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                width = Math.max(width, Double.toString(values[row][col]).length());
            }
        }
        String cell = "%-" + (width + 1) + "s";
        StringBuilder sb = new StringBuilder();
        try (Formatter formatter = new Formatter(sb, Locale.US)) {
            formatter.format(cell, "");
            for (int col = 0; col < n; col++) {
                formatter.format(cell, shortNames[col]);
            }
            sb.append("\n\n");
            for (int row = 0; row < n; row++) {
                formatter.format(cell, shortNames[row]);
                for (int col = 0; col < n; col++) {
                    formatter.format(cell, Double.toString(values[row][col]));
                }
                sb.append("\n\n");
            }
        }
        return sb.toString();
    }

    /** Prints command-line usage to standard error. */
    private static void printUsage() {
        System.err.println("Usage:");
        System.err.println("java HierarchicalClassifierTrainer [--train] <questionLang> <corpusLang>\n");
        System.err.println(" - <questionLang> and <corpusLang> must be one of the following:");
        System.err.println("     en_US, ja_JP, jp_JP, zh_TW, zh_CN");
        System.err.println(" - Outputs a trained model in the current directory if --train is used.");
        System.err.println(" - Otherwise, performs an evaluation using the configuration in the");
        System.err.println("     properties file and outputs a report describing the results.");
    }

    /**
     * Command-line entry point: {@code [--train] <questionLang> <corpusLang>}. With
     * {@code --train}, trains and saves a classifier. Otherwise runs an experiment, writes a report
     * file and opens a results viewer. Exits with status 1 on invalid arguments.
     */
    public static void main(String[] strArr) throws Exception {
        boolean train = strArr.length > 0 && strArr[0].equals("--train");
        int langIndex = train ? 1 : 0;
        // Exactly two language arguments must follow the optional --train flag.
        if (strArr.length != langIndex + 2) {
            printUsage();
            System.exit(1);
            return;
        }
        Pair<Language, Language> pair = new Pair<>(Language.valueOf(strArr[langIndex]), Language.valueOf(strArr[langIndex + 1]));
        HierarchicalClassifierTrainer trainer = new HierarchicalClassifierTrainer(pair);
        trainer.initialize();
        if (train) {
            System.out.println("Training classifier...");
            trainer.trainClassifier();
            trainer.saveClassifier();
            System.out.println("Classifier saved.");
            return;
        }
        System.out.println("Running experiment...");
        Evaluation result = trainer.runExperiment();
        FileUtil.writeFile(trainer.createReport(), strArr[0] + ".report" + System.currentTimeMillis() + ".txt", "UTF-8");
        new ViewerFrame(strArr[0], result.toGUI()).setVisible(true);
    }
}
