package info.ephyra.answerselection.definitional;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.algorithms.knn.KnnLearner;
import edu.cmu.minorthird.classify.algorithms.linear.BalancedWinnow;
import edu.cmu.minorthird.classify.algorithms.linear.MarginPerceptron;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import edu.cmu.minorthird.classify.algorithms.linear.NegativeBinomialLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.algorithms.random.RandomElement;
import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.CrossValidatedDataset;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.ui.Recommended;
import info.ephyra.io.MsgPrinter;
import info.ephyra.questionanalysis.TermExpander;
import info.ephyra.util.FileUtils;
import info.ephyra.util.StringUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:info/ephyra/answerselection/definitional/AnswerSelectorLearner.class */
public class AnswerSelectorLearner {
    private static final String ADA_BOOST_10_M = "AdaBoost10";
    private static final String ADA_BOOST_100_M = "AdaBoost100";
    private static final String ADA_BOOST_LOG_M = "AdaBoostLog";
    private static final String BALANCED_WINNOW_M = "BalancedWinnow";
    private static final String DECISION_TREE_M = "DecisionTree";
    private static final String KNN_M = "KNN";
    private static final String MARGIN_PERCEPTRON_M = "MarginPerceptron";
    private static final String MAX_ENT_M = "MaxEnt";
    private static final String NAIVE_BAYES_M = "NaiveBayes";
    private static final String NEGATIVE_BINOMIAL_M = "NegativeBinomial";
    private static final String VOTED_PERCEPTRON_M = "VotedPerceptron";
    private static final int FOLDS = 10;
    private static final String[] ALL_FEATURES = {FeatureExtractor.CHARACTER_LENGTH_F, FeatureExtractor.CONTAINS_3P_PRONOUN_F, FeatureExtractor.POS_MIX_DISTANCE_F, FeatureExtractor.RANK_F, FeatureExtractor.RATIO_FUNCTION_WORDS_F, FeatureExtractor.RATIO_NE_TOKENS_F, FeatureExtractor.RATIO_TARGET_WORDS_F, FeatureExtractor.TOKEN_LENGTH_F, FeatureExtractor.WEB_TERM_COVERAGE_F};
    private static final String[] SELECTED_FEATURES = {FeatureExtractor.CHARACTER_LENGTH_F, FeatureExtractor.CONTAINS_3P_PRONOUN_F, FeatureExtractor.POS_MIX_DISTANCE_F, FeatureExtractor.RANK_F, FeatureExtractor.RATIO_FUNCTION_WORDS_F, FeatureExtractor.RATIO_NE_TOKENS_F, FeatureExtractor.RATIO_TARGET_WORDS_F, FeatureExtractor.TOKEN_LENGTH_F, FeatureExtractor.WEB_TERM_COVERAGE_F};
    private static int ROUNDS = 10;
    private static String ADA_BOOST_N_M = "AdaBoost" + ROUNDS;
    private static final String SVM_M = "SVM";
    private static final String[] ALL_MODELS = {SVM_M};
    private static final String SELECTED_MODEL = ADA_BOOST_N_M;

    private static Instance createInstance(DefinitionalAnswer definitionalAnswer, String[] strArr) {
        MutableInstance mutableInstance = new MutableInstance(definitionalAnswer, definitionalAnswer.getQuestionID());
        HashSet hashSet = new HashSet();
        for (String str : strArr) {
            hashSet.add(str);
        }
        double[] features = definitionalAnswer.getFeatures();
        String[] featureDomains = definitionalAnswer.getFeatureDomains();
        String[] featureDescs = definitionalAnswer.getFeatureDescs();
        for (int i = 0; i < features.length; i++) {
            if (hashSet.contains(featureDescs[i])) {
                Feature feature = new Feature(featureDescs[i]);
                if (featureDomains[i].equals(FeatureExtractor.BINARY_D)) {
                    if (features[i] > TermExpander.MIN_EXPANSION_WEIGHT) {
                        mutableInstance.addBinary(feature);
                    } else if (featureDomains[i].equals(FeatureExtractor.NUMERIC_D)) {
                        mutableInstance.addNumeric(feature, features[i]);
                    } else if (featureDomains[i].equals(FeatureExtractor.CONTINUOUS_D)) {
                    }
                }
                mutableInstance.addNumeric(feature, features[i]);
            }
        }
        return mutableInstance;
    }

    private static Example createExample(DefinitionalAnswer definitionalAnswer, String[] strArr) {
        return new Example(createInstance(definitionalAnswer, strArr), new ClassLabel(definitionalAnswer.coversNuggets() ? "POS" : "NEG"));
    }

    private static Dataset createDataset(DefinitionalAnswer[] definitionalAnswerArr, String[] strArr) {
        BasicDataset basicDataset = new BasicDataset();
        for (DefinitionalAnswer definitionalAnswer : definitionalAnswerArr) {
            basicDataset.add(createExample(definitionalAnswer, strArr));
        }
        return basicDataset;
    }

    private static ClassifierLearner createLearner(String str) {
        AdaBoost adaBoost = null;
        if (str.equals(ADA_BOOST_10_M)) {
            adaBoost = new AdaBoost();
        } else if (str.equals(ADA_BOOST_100_M)) {
            adaBoost = new Recommended.BoostedStumpLearner();
        } else if (str.equals(ADA_BOOST_N_M)) {
            adaBoost = new AdaBoost(new DecisionTreeLearner(), ROUNDS);
        } else if (str.equals(ADA_BOOST_LOG_M)) {
            adaBoost = new AdaBoost.L();
        } else if (str.equals(BALANCED_WINNOW_M)) {
            adaBoost = new BalancedWinnow();
        } else if (str.equals(DECISION_TREE_M)) {
            adaBoost = new DecisionTreeLearner();
        } else if (str.equals(KNN_M)) {
            adaBoost = new KnnLearner();
        } else if (str.equals(MARGIN_PERCEPTRON_M)) {
            adaBoost = new MarginPerceptron();
        } else if (str.equals(MAX_ENT_M)) {
            adaBoost = new MaxEntLearner();
        } else if (str.equals(NAIVE_BAYES_M)) {
            adaBoost = new NaiveBayes();
        } else if (str.equals(NEGATIVE_BINOMIAL_M)) {
            adaBoost = new NegativeBinomialLearner();
        } else if (str.equals(SVM_M)) {
            adaBoost = new SVMLearner();
        } else if (str.equals(VOTED_PERCEPTRON_M)) {
            adaBoost = new VotedPerceptron();
        } else {
            MsgPrinter.printErrorMsg("Unknown model: " + str);
            System.exit(1);
        }
        return adaBoost;
    }

    private static Hashtable<String, Double> calculateNSRRs(Evaluation evaluation) {
        int numExamples = evaluation.numExamples();
        Hashtable hashtable = new Hashtable();
        for (int i = 0; i < numExamples; i++) {
            Evaluation.Entry entry = evaluation.getEntry(i);
            DefinitionalAnswer definitionalAnswer = (DefinitionalAnswer) entry.instance.getSource();
            definitionalAnswer.setRelevance(entry.predicted.posProbability());
            String str = String.valueOf(definitionalAnswer.getQuestionID()) + "_" + definitionalAnswer.getRunID();
            ArrayList arrayList = (ArrayList) hashtable.get(str);
            if (arrayList == null) {
                arrayList = new ArrayList();
                hashtable.put(str, arrayList);
            }
            arrayList.add(definitionalAnswer);
        }
        Hashtable hashtable2 = new Hashtable();
        Hashtable hashtable3 = new Hashtable();
        Hashtable hashtable4 = new Hashtable();
        Iterator it = hashtable.keySet().iterator();
        while (it.hasNext()) {
            ArrayList arrayList2 = (ArrayList) hashtable.get((String) it.next());
            double d = 0.0d;
            int i2 = 0;
            DefinitionalAnswer[] definitionalAnswerArr = (DefinitionalAnswer[]) arrayList2.toArray(new DefinitionalAnswer[arrayList2.size()]);
            for (int i3 = 0; i3 < definitionalAnswerArr.length; i3++) {
                if (definitionalAnswerArr[i3].coversNuggets()) {
                    d += 1.0d / (i3 + 1);
                    i2++;
                }
            }
            if (i2 != 0) {
                double d2 = 0.0d;
                for (int i4 = 0; i4 < i2; i4++) {
                    d2 += 1.0d / (i4 + 1);
                }
                double d3 = d / d2;
                double d4 = 0.0d;
                Arrays.sort(definitionalAnswerArr, new Comparator<DefinitionalAnswer>() { // from class: info.ephyra.answerselection.definitional.AnswerSelectorLearner.1
                    @Override // java.util.Comparator
                    public int compare(DefinitionalAnswer definitionalAnswer2, DefinitionalAnswer definitionalAnswer3) {
                        double relevance = definitionalAnswer3.getRelevance() - definitionalAnswer2.getRelevance();
                        if (relevance < TermExpander.MIN_EXPANSION_WEIGHT) {
                            return -1;
                        }
                        return relevance > TermExpander.MIN_EXPANSION_WEIGHT ? 1 : 0;
                    }
                });
                for (int i5 = 0; i5 < definitionalAnswerArr.length; i5++) {
                    if (definitionalAnswerArr[i5].coversNuggets()) {
                        d4 += 1.0d / (i5 + 1);
                    }
                }
                double d5 = d4 / d2;
                String runID = ((DefinitionalAnswer) arrayList2.get(0)).getRunID();
                Integer num = (Integer) hashtable2.get(runID);
                if (num == null) {
                    num = new Integer(0);
                }
                hashtable2.put(runID, Integer.valueOf(num.intValue() + 1));
                Double d6 = (Double) hashtable3.get(runID);
                if (d6 == null) {
                    d6 = new Double(TermExpander.MIN_EXPANSION_WEIGHT);
                }
                hashtable3.put(runID, Double.valueOf(d6.doubleValue() + d3));
                Double d7 = (Double) hashtable4.get(runID);
                if (d7 == null) {
                    d7 = new Double(TermExpander.MIN_EXPANSION_WEIGHT);
                }
                hashtable4.put(runID, Double.valueOf(d7.doubleValue() + d5));
            }
        }
        for (String str2 : hashtable2.keySet()) {
            int intValue = ((Integer) hashtable2.get(str2)).intValue();
            hashtable3.put(str2, Double.valueOf(((Double) hashtable3.get(str2)).doubleValue() / intValue));
            hashtable4.put(str2, Double.valueOf(((Double) hashtable4.get(str2)).doubleValue() / intValue));
        }
        int size = hashtable2.size();
        double d8 = 0.0d;
        Iterator it2 = hashtable3.values().iterator();
        while (it2.hasNext()) {
            d8 += ((Double) it2.next()).doubleValue();
        }
        double d9 = d8 / size;
        double d10 = 0.0d;
        Iterator it3 = hashtable4.values().iterator();
        while (it3.hasNext()) {
            d10 += ((Double) it3.next()).doubleValue();
        }
        double d11 = d10 / size;
        Hashtable<String, Double> hashtable5 = new Hashtable<>();
        for (String str3 : hashtable2.keySet()) {
            hashtable5.put(String.valueOf(str3) + " NSRR (original)", (Double) hashtable3.get(str3));
            hashtable5.put(String.valueOf(str3) + " NSRR (reranked)", (Double) hashtable4.get(str3));
        }
        hashtable5.put("Overall NSRR (original)", Double.valueOf(d9));
        hashtable5.put("Overall NSRR (reranked)", Double.valueOf(d11));
        return hashtable5;
    }

    private static String createReport(String[] strArr, String[] strArr2, String str, Evaluation evaluation, Hashtable<String, Double> hashtable, long j) {
        String str2 = String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf("") + "Parameters:\n") + "-----------\n") + "Model:    " + str + "\n") + "Features: " + StringUtils.concat(strArr2, ", ") + "\n") + "Data set: " + StringUtils.concat(strArr, ", ") + " (" + evaluation.numExamples() + " examples)\n") + "Folds:    10\n") + "\n") + "Overall Statistics:\n") + "-------------------\n";
        double[] summaryStatistics = evaluation.summaryStatistics();
        String[] summaryStatisticNames = evaluation.summaryStatisticNames();
        int i = 0;
        for (int i2 = 0; i2 < summaryStatistics.length; i2++) {
            i = Math.max(summaryStatisticNames[i2].length(), i);
        }
        for (int i3 = 0; i3 < summaryStatistics.length; i3++) {
            str2 = String.valueOf(String.valueOf(String.valueOf(str2) + summaryStatisticNames[i3] + ": ") + StringUtils.repeat(" ", i - summaryStatisticNames[i3].length())) + summaryStatistics[i3] + "\n";
        }
        String str3 = String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(str2) + "Overall NSRR (original): ") + StringUtils.repeat(" ", i - "Overall NSRR (original)".length())) + hashtable.get("Overall NSRR (original)") + "\n") + "Overall NSRR (reranked): ") + StringUtils.repeat(" ", i - "Overall NSRR (reranked)".length())) + hashtable.get("Overall NSRR (reranked)") + "\n") + "Runtime: ") + StringUtils.repeat(" ", i - 7)) + j + " ms\n") + "\n") + "Per-Run Statistics:\n") + "-------------------\n";
        String[] strArr3 = (String[]) hashtable.keySet().toArray(new String[hashtable.size()]);
        Arrays.sort(strArr3, new Comparator<String>() { // from class: info.ephyra.answerselection.definitional.AnswerSelectorLearner.2
            @Override // java.util.Comparator
            public int compare(String str4, String str5) {
                String[] split = str4.split("_");
                String[] split2 = str5.split("_");
                int length = split.length - split2.length;
                if (length != 0) {
                    return length;
                }
                for (int i4 = 0; i4 < split.length; i4++) {
                    int length2 = split[i4].length() - split2[i4].length();
                    if (length2 != 0) {
                        return length2;
                    }
                    int compareTo = split[i4].compareTo(split2[i4]);
                    if (compareTo != 0) {
                        return compareTo;
                    }
                }
                return 0;
            }
        });
        for (String str4 : strArr3) {
            if (!str4.equals("Overall NSRR (original)") && !str4.equals("Overall NSRR (reranked)")) {
                str3 = String.valueOf(String.valueOf(String.valueOf(str3) + str4 + ": ") + StringUtils.repeat(" ", i - str4.length())) + hashtable.get(str4) + "\n";
            }
        }
        return str3;
    }

    public static Classifier train(DefinitionalAnswer[] definitionalAnswerArr, String[] strArr, String str) {
        Dataset createDataset = createDataset(definitionalAnswerArr, strArr);
        return new DatasetClassifierTeacher(createDataset).train(createLearner(str));
    }

    public static Classifier train(DefinitionalAnswer[] definitionalAnswerArr) {
        return train(definitionalAnswerArr, SELECTED_FEATURES, SELECTED_MODEL);
    }

    public static Evaluation evaluate(DefinitionalAnswer[] definitionalAnswerArr, String[] strArr, String str) {
        return new CrossValidatedDataset(createLearner(str), createDataset(definitionalAnswerArr, strArr), new CrossValSplitter(new RandomElement(System.currentTimeMillis()), 10), true).getEvaluation();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v4, types: [java.lang.String[], java.lang.String[][]] */
    public static String[][] evaluateAll(DefinitionalAnswer[] definitionalAnswerArr, String[] strArr, String str) {
        double d = Double.NEGATIVE_INFINITY;
        ?? r0 = new String[2];
        for (String[] strArr2 : new String[]{new String[]{FeatureExtractor.POS_MIX_DISTANCE_F, FeatureExtractor.RANK_F, FeatureExtractor.RATIO_FUNCTION_WORDS_F, FeatureExtractor.RATIO_TARGET_WORDS_F}}) {
            for (String str2 : ALL_MODELS) {
                String str3 = String.valueOf(str2) + "_" + StringUtils.concat(strArr2, "+") + "_" + StringUtils.concat(strArr, "+") + "_10Fold";
                File file = new File(str, str3);
                if (file.exists()) {
                    MsgPrinter.printErrorMsg("File " + str3 + " already exists.");
                } else {
                    String str4 = "Evaluating model " + str2 + " with feature(s) " + StringUtils.concat(strArr2, ", ") + " (" + MsgPrinter.getTimestamp() + ")...";
                    MsgPrinter.printStatusMsg(StringUtils.repeat("-", str4.length()));
                    MsgPrinter.printStatusMsg(str4);
                    MsgPrinter.printStatusMsg(StringUtils.repeat("-", str4.length()));
                    long currentTimeMillis = System.currentTimeMillis();
                    Evaluation evaluate = evaluate(definitionalAnswerArr, strArr2, str2);
                    long currentTimeMillis2 = System.currentTimeMillis() - currentTimeMillis;
                    Hashtable<String, Double> calculateNSRRs = calculateNSRRs(evaluate);
                    try {
                        FileUtils.writeString(createReport(strArr, strArr2, str2, evaluate, calculateNSRRs, currentTimeMillis2), file, "UTF-8");
                    } catch (IOException e) {
                        MsgPrinter.printErrorMsg("Could not write to file " + str3 + ":");
                        MsgPrinter.printErrorMsg(e.toString());
                        System.exit(1);
                    }
                    double doubleValue = calculateNSRRs.get("Overall NSRR (reranked)").doubleValue();
                    if (doubleValue > d) {
                        d = doubleValue;
                        r0[0] = strArr2;
                        String[] strArr3 = new String[1];
                        strArr3[0] = str2;
                        r0[1] = strArr3;
                    }
                }
            }
        }
        return r0;
    }

    public static void main(String[] strArr) {
        MsgPrinter.enableStatusMsgs(true);
        MsgPrinter.enableErrorMsgs(true);
        if (strArr.length < 2) {
            MsgPrinter.printUsage("java AnswerSelectorLearner serialized_answers_file output_dir");
            System.exit(1);
        }
        String str = strArr[0];
        String str2 = strArr[1];
        DefinitionalAnswer[] deserializeAnswers = DefinitionalParser.deserializeAnswers(str);
        String[] split = new File(str).getName().split("(\\+|\\..++)");
        String[][] evaluateAll = evaluateAll(deserializeAnswers, split, new File(str2, "reports").getPath());
        String[] strArr2 = evaluateAll[0];
        String str3 = evaluateAll[1][0];
        String str4 = "Training classifier using model " + str3 + " with feature(s) " + StringUtils.concat(strArr2, ", ") + " (" + MsgPrinter.getTimestamp() + ")...";
        MsgPrinter.printStatusMsg(StringUtils.repeat("-", str4.length()));
        MsgPrinter.printStatusMsg(str4);
        MsgPrinter.printStatusMsg(StringUtils.repeat("-", str4.length()));
        try {
            FileUtils.writeSerialized(train(deserializeAnswers, strArr2, str3), new File(new File(str2, "classifiers").getPath(), String.valueOf(str3) + "_" + StringUtils.concat(strArr2, "+") + "_" + StringUtils.concat(split, "+") + ".serialized"));
        } catch (IOException e) {
            MsgPrinter.printErrorMsg("Could not write to file " + str + ":");
            MsgPrinter.printErrorMsg(e.toString());
            System.exit(1);
        }
    }
}
