package info.ephyra.answerselection.filters;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.algorithms.knn.KnnLearner;
import edu.cmu.minorthird.classify.algorithms.linear.BalancedWinnow;
import edu.cmu.minorthird.classify.algorithms.linear.KWayMixtureLearner;
import edu.cmu.minorthird.classify.algorithms.linear.MarginPerceptron;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import edu.cmu.minorthird.classify.algorithms.linear.NegativeBinomialLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.algorithms.random.RandomElement;
import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.CrossValidatedDataset;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.ui.Recommended;
import info.ephyra.io.MsgPrinter;
import info.ephyra.questionanalysis.AnalyzedQuestion;
import info.ephyra.search.Result;
import info.ephyra.util.ArrayUtils;
import info.ephyra.util.FileUtils;
import info.ephyra.util.StringUtils;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:info/ephyra/answerselection/filters/ScoreNormalizationFilter.class */
public class ScoreNormalizationFilter extends Filter {
    private static final String KNN_M = "KNN";
    private static final String KWAY_MIXTURE_M = "KWayMixture";
    private static final String SVM_M = "SVM";
    private static final int NUM_FOLDS = 3;
    private static Classifier classifier;
    private static final String SCORE_F = "Score";
    private static final String EXTRACTORS_F = "Extractors";
    private static final String ANSWER_TYPES_F = "AnswerTypes";
    private static final String NUM_ANSWERS_F = "NumAnswers";
    private static final String MEAN_SCORE_F = "MeanScore";
    private static final String MAX_SCORE_F = "MaxScore";
    private static final String MIN_SCORE_F = "MinScore";
    private static final String[] ALL_FEATURES = {SCORE_F, EXTRACTORS_F, ANSWER_TYPES_F, NUM_ANSWERS_F, MEAN_SCORE_F, MAX_SCORE_F, MIN_SCORE_F};
    private static final String[] SELECTED_FEATURES = {SCORE_F, EXTRACTORS_F};
    private static int NUM_BOOSTS = 70;
    private static String ADA_BOOST_N_M = "AdaBoost" + NUM_BOOSTS;
    private static final String ADA_BOOST_10_M = "AdaBoost10";
    private static final String ADA_BOOST_100_M = "AdaBoost100";
    private static final String ADA_BOOST_L_M = "AdaBoostL";
    private static final String BALANCED_WINNOW_M = "BalancedWinnow";
    private static final String DECISION_TREE_M = "DecisionTree";
    private static final String MARGIN_PERCEPTRON_M = "MarginPerceptron";
    private static final String MAX_ENT_M = "MaxEnt";
    private static final String NAIVE_BAYES_M = "NaiveBayes";
    private static final String NEGATIVE_BINOMIAL_M = "NegativeBinomial";
    private static final String VOTED_PERCEPTRON_M = "VotedPerceptron";
    private static final String[] ALL_MODELS = {ADA_BOOST_10_M, ADA_BOOST_100_M, ADA_BOOST_N_M, ADA_BOOST_L_M, BALANCED_WINNOW_M, DECISION_TREE_M, MARGIN_PERCEPTRON_M, MAX_ENT_M, NAIVE_BAYES_M, NEGATIVE_BINOMIAL_M, VOTED_PERCEPTRON_M};
    private static final String SELECTED_MODEL = ADA_BOOST_N_M;

    private static Result[] readSerializedResults(File file) {
        ArrayList arrayList = new ArrayList();
        try {
            ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(file));
            if (!(objectInputStream.readObject() instanceof AnalyzedQuestion)) {
                MsgPrinter.printErrorMsg("First serialized object is not anAnalyzedQuestion.");
                System.exit(1);
            }
            while (true) {
                try {
                    arrayList.add((Result) objectInputStream.readObject());
                } catch (EOFException e) {
                    objectInputStream.close();
                    return (Result[]) arrayList.toArray(new Result[arrayList.size()]);
                }
            }
        } catch (Exception e2) {
            MsgPrinter.printErrorMsg("Could not read serialized results:");
            MsgPrinter.printErrorMsg(e2.toString());
            System.exit(1);
            return (Result[]) arrayList.toArray(new Result[arrayList.size()]);
        }
    }

    private static void addScoreFeature(MutableInstance mutableInstance, Result result) {
        mutableInstance.addNumeric(new Feature(SCORE_F), result.getScore());
    }

    private static void addExtractorFeature(MutableInstance mutableInstance, Result result) {
        mutableInstance.addBinary(new Feature(result.getExtractionTechniques()[0]));
    }

    private static void addAnswerTypeFeatures(MutableInstance mutableInstance, Result result) {
        for (String str : result.getQuery().getAnalyzedQuestion().getAnswerTypes()) {
            mutableInstance.addBinary(new Feature(str.split("->")));
        }
    }

    private static void addNumAnswersFeature(MutableInstance mutableInstance, Result result, Result[] resultArr) {
        int i = 0;
        for (Result result2 : resultArr) {
            if (result2.getScore() > 0.0f && result2.getScore() < Float.POSITIVE_INFINITY) {
                i++;
            }
        }
        mutableInstance.addNumeric(new Feature(NUM_ANSWERS_F), i);
    }

    private static void addMeanScoreFeature(MutableInstance mutableInstance, Result result, Result[] resultArr) {
        double d = 0.0d;
        int i = 0;
        for (Result result2 : resultArr) {
            if (result2.getScore() > 0.0f && result2.getScore() < Float.POSITIVE_INFINITY) {
                d += result2.getScore();
                i++;
            }
        }
        mutableInstance.addNumeric(new Feature(MEAN_SCORE_F), d / i);
    }

    private static void addMaxScoreFeature(MutableInstance mutableInstance, Result result, Result[] resultArr) {
        double d = 0.0d;
        for (Result result2 : resultArr) {
            if (result2.getScore() > 0.0f && result2.getScore() < Float.POSITIVE_INFINITY) {
                d = Math.max(result2.getScore(), d);
            }
        }
        mutableInstance.addNumeric(new Feature(MAX_SCORE_F), d);
    }

    private static void addMinScoreFeature(MutableInstance mutableInstance, Result result, Result[] resultArr) {
        double d = Double.POSITIVE_INFINITY;
        for (Result result2 : resultArr) {
            if (result2.getScore() > 0.0f && result2.getScore() < Float.POSITIVE_INFINITY) {
                d = Math.min(result2.getScore(), d);
            }
        }
        mutableInstance.addNumeric(new Feature(MIN_SCORE_F), d);
    }

    private static void addSelectedFeatures(MutableInstance mutableInstance, String[] strArr, Result result, Result[] resultArr) {
        HashSet hashSet = new HashSet();
        for (String str : strArr) {
            hashSet.add(str);
        }
        if (hashSet.contains(SCORE_F)) {
            addScoreFeature(mutableInstance, result);
        }
        if (hashSet.contains(EXTRACTORS_F)) {
            addExtractorFeature(mutableInstance, result);
        }
        if (hashSet.contains(ANSWER_TYPES_F)) {
            addAnswerTypeFeatures(mutableInstance, result);
        }
        if (hashSet.contains(NUM_ANSWERS_F)) {
            addNumAnswersFeature(mutableInstance, result, resultArr);
        }
        if (hashSet.contains(MEAN_SCORE_F)) {
            addMeanScoreFeature(mutableInstance, result, resultArr);
        }
        if (hashSet.contains(MAX_SCORE_F)) {
            addMaxScoreFeature(mutableInstance, result, resultArr);
        }
        if (hashSet.contains(MIN_SCORE_F)) {
            addMinScoreFeature(mutableInstance, result, resultArr);
        }
    }

    private static Instance createInstance(String[] strArr, Result result, Result[] resultArr) {
        MutableInstance mutableInstance = new MutableInstance(result);
        addSelectedFeatures(mutableInstance, strArr, result, resultArr);
        return mutableInstance;
    }

    private static Instance createInstance(String[] strArr, Result result, Result[] resultArr, String str) {
        MutableInstance mutableInstance = new MutableInstance(result, str);
        addSelectedFeatures(mutableInstance, strArr, result, resultArr);
        return mutableInstance;
    }

    private static Example createExample(String[] strArr, Result result, Result[] resultArr, String str) {
        return new Example(createInstance(strArr, result, resultArr, str), new ClassLabel(result.isCorrect() ? "POS" : "NEG"));
    }

    private static Dataset createDataset(String[] strArr, String str) {
        BasicDataset basicDataset = new BasicDataset();
        for (File file : FileUtils.getFilesRec(str)) {
            String name = file.getName();
            if (name.endsWith(".serialized")) {
                String replace = name.replace(".serialized", "");
                Result[] readSerializedResults = readSerializedResults(file);
                for (Result result : readSerializedResults) {
                    if (result.getScore() > 0.0f && result.getScore() != Float.POSITIVE_INFINITY && result.getExtractionTechniques() != null && result.getExtractionTechniques().length == 1) {
                        basicDataset.add(createExample(strArr, result, readSerializedResults, replace));
                    }
                }
            }
        }
        return basicDataset;
    }

    private static ClassifierLearner createLearner(String str) {
        AdaBoost adaBoost = null;
        if (str.equals(ADA_BOOST_10_M)) {
            adaBoost = new AdaBoost();
        } else if (str.equals(ADA_BOOST_100_M)) {
            adaBoost = new Recommended.BoostedStumpLearner();
        } else if (str.equals(ADA_BOOST_N_M)) {
            adaBoost = new AdaBoost(new DecisionTreeLearner(), NUM_BOOSTS);
        } else if (str.equals(ADA_BOOST_L_M)) {
            adaBoost = new AdaBoost.L();
        } else if (str.equals(BALANCED_WINNOW_M)) {
            adaBoost = new BalancedWinnow();
        } else if (str.equals(DECISION_TREE_M)) {
            adaBoost = new DecisionTreeLearner();
        } else if (str.equals(KNN_M)) {
            adaBoost = new KnnLearner();
        } else if (str.equals(KWAY_MIXTURE_M)) {
            adaBoost = new KWayMixtureLearner();
        } else if (str.equals(MARGIN_PERCEPTRON_M)) {
            adaBoost = new MarginPerceptron();
        } else if (str.equals(MAX_ENT_M)) {
            adaBoost = new MaxEntLearner();
        } else if (str.equals(NAIVE_BAYES_M)) {
            adaBoost = new NaiveBayes();
        } else if (str.equals(NEGATIVE_BINOMIAL_M)) {
            adaBoost = new NegativeBinomialLearner();
        } else if (str.equals(SVM_M)) {
            adaBoost = new SVMLearner();
        } else if (str.equals(VOTED_PERCEPTRON_M)) {
            adaBoost = new VotedPerceptron();
        } else {
            MsgPrinter.printErrorMsg("Unknown model: " + str);
            System.exit(1);
        }
        return adaBoost;
    }

    private static String createReport(String[] strArr, String[] strArr2, String str, Evaluation evaluation, long j) {
        String str2 = String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf(String.valueOf("") + "Parameters:\n") + "-----------\n") + "Data set: " + StringUtils.concat(strArr, ", ") + " (" + evaluation.numExamples() + " examples)\n") + "Features: " + StringUtils.concat(strArr2, ", ") + "\n") + "Model:    " + str + "\n") + "\n") + "Statistics:\n") + "-----------\n";
        double[] summaryStatistics = evaluation.summaryStatistics();
        String[] summaryStatisticNames = evaluation.summaryStatisticNames();
        int i = 0;
        for (String str3 : summaryStatisticNames) {
            i = Math.max(str3.length(), i);
        }
        for (int i2 = 0; i2 < summaryStatisticNames.length; i2++) {
            str2 = String.valueOf(String.valueOf(String.valueOf(str2) + summaryStatisticNames[i2] + ": ") + StringUtils.repeat(" ", i - summaryStatisticNames[i2].length())) + summaryStatistics[i2] + "\n";
        }
        return String.valueOf(String.valueOf(String.valueOf(str2) + "Runtime: ") + StringUtils.repeat(" ", i - 7)) + j + " ms\n";
    }

    public static Classifier train(String str) {
        return train(str, SELECTED_FEATURES, SELECTED_MODEL);
    }

    public static Classifier train(String str, String[] strArr, String str2) {
        Dataset createDataset = createDataset(strArr, str);
        return new DatasetClassifierTeacher(createDataset).train(createLearner(str2));
    }

    public static Evaluation evaluate(String str, String[] strArr, String str2) {
        return new CrossValidatedDataset(createLearner(str2), createDataset(strArr, str), new CrossValSplitter(new RandomElement(System.currentTimeMillis()), 3), true).getEvaluation();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v9, types: [java.lang.String[], java.lang.String[][]] */
    public static String[][] evaluateAll(String str, String str2) {
        Object[][] nonemptySubsets = ArrayUtils.getNonemptySubsets(ALL_FEATURES);
        String[] strArr = new String[nonemptySubsets.length];
        for (int i = 0; i < nonemptySubsets.length; i++) {
            strArr[i] = new String[nonemptySubsets[i].length];
            for (int i2 = 0; i2 < nonemptySubsets[i].length; i2++) {
                strArr[i][i2] = (String) nonemptySubsets[i][i2];
            }
        }
        double d = -1.0d;
        ?? r0 = new String[2];
        for (String[] strArr2 : strArr) {
            for (String str3 : ALL_MODELS) {
                String[] visibleSubDirs = FileUtils.getVisibleSubDirs(str);
                File file = new File(str2, String.valueOf(str3) + "_" + StringUtils.concat(strArr2, "+") + "_" + StringUtils.concat(visibleSubDirs, "+"));
                if (file.exists()) {
                    MsgPrinter.printErrorMsg("File " + file + " already exists.");
                } else {
                    String str4 = "Evaluating model " + str3 + " with feature(s) " + StringUtils.concat(strArr2, ", ") + " (" + MsgPrinter.getTimestamp() + ")...";
                    MsgPrinter.printStatusMsg(StringUtils.repeat("-", str4.length()));
                    MsgPrinter.printStatusMsg(str4);
                    MsgPrinter.printStatusMsg(StringUtils.repeat("-", str4.length()));
                    long currentTimeMillis = System.currentTimeMillis();
                    Evaluation evaluate = evaluate(str, strArr2, str3);
                    try {
                        FileUtils.writeString(createReport(visibleSubDirs, strArr2, str3, evaluate, System.currentTimeMillis() - currentTimeMillis), file, "UTF-8");
                    } catch (IOException e) {
                        MsgPrinter.printErrorMsg("Failed to write report to file " + file + ":");
                        MsgPrinter.printErrorMsg(e.toString());
                        System.exit(1);
                    }
                    double f1 = evaluate.f1();
                    if (f1 > d) {
                        d = f1;
                        r0[0] = strArr2;
                        String[] strArr3 = new String[1];
                        strArr3[0] = str3;
                        r0[1] = strArr3;
                    }
                }
            }
        }
        return r0;
    }

    public static void main(String[] strArr) {
        MsgPrinter.enableStatusMsgs(true);
        MsgPrinter.enableErrorMsgs(true);
        if (strArr.length < 2) {
            MsgPrinter.printUsage("java ScoreNormalizationFilter serialized_results_dir output_dir");
            System.exit(1);
        }
        String str = strArr[0];
        String str2 = strArr[1];
        String[] strArr2 = SELECTED_FEATURES;
        String str3 = SELECTED_MODEL;
        String str4 = "Training classifier using model " + str3 + " with feature(s) " + StringUtils.concat(strArr2, ", ") + " (" + MsgPrinter.getTimestamp() + ")...";
        MsgPrinter.printStatusMsg(StringUtils.repeat("-", str4.length()));
        MsgPrinter.printStatusMsg(str4);
        MsgPrinter.printStatusMsg(StringUtils.repeat("-", str4.length()));
        Classifier train = train(str, strArr2, str3);
        String path = new File(str2, "classifiers").getPath();
        String str5 = String.valueOf(str3) + "_" + StringUtils.concat(strArr2, "+") + "_" + StringUtils.concat(FileUtils.getVisibleSubDirs(str), "+") + ".serialized";
        try {
            FileUtils.writeSerialized(train, new File(path, str5));
        } catch (IOException e) {
            MsgPrinter.printErrorMsg("Failed to serialize classifier to file " + str5 + ":");
            MsgPrinter.printErrorMsg(e.toString());
            System.exit(1);
        }
        MsgPrinter.printStatusMsg("...done.");
    }

    public static void loadClassifier(String str) {
        try {
            classifier = (Classifier) FileUtils.readSerialized(new File(str));
        } catch (Exception e) {
            MsgPrinter.printErrorMsg("Failed to load classifier:");
            MsgPrinter.printErrorMsg(e.toString());
        }
    }

    public ScoreNormalizationFilter(String str) {
        loadClassifier(str);
    }

    public Result[] preserveOrderResorting(Result[] resultArr) {
        ArrayList arrayList = new ArrayList();
        Hashtable hashtable = new Hashtable();
        for (Result result : resultArr) {
            if (result.getScore() <= 0.0f || result.getScore() == Float.POSITIVE_INFINITY || result.getExtractionTechniques() == null || result.getExtractionTechniques().length != 1) {
                arrayList.add(result);
            } else {
                String str = result.getExtractionTechniques()[0];
                ArrayList arrayList2 = (ArrayList) hashtable.get(str);
                if (arrayList2 == null) {
                    arrayList2 = new ArrayList();
                    hashtable.put(str, arrayList2);
                }
                arrayList2.add(result);
            }
        }
        for (List list : hashtable.values()) {
            Result[] apply = new NormalizedScoreSorterFilter().apply((Result[]) list.toArray(new Result[list.size()]));
            float[] fArr = new float[apply.length];
            for (int i = 0; i < apply.length; i++) {
                fArr[i] = apply[i].getNormScore();
            }
            Result[] apply2 = new ScoreSorterFilter().apply(apply);
            for (int i2 = 0; i2 < apply2.length; i2++) {
                apply2[i2].setNormScore(fArr[i2]);
            }
            for (Result result2 : apply2) {
                arrayList.add(result2);
            }
        }
        return (Result[]) arrayList.toArray(new Result[arrayList.size()]);
    }

    public Result[] preserveOrderAveraging(Result[] resultArr) {
        Hashtable hashtable = new Hashtable();
        for (Result result : resultArr) {
            if (result.getScore() > 0.0f && result.getScore() != Float.POSITIVE_INFINITY && result.getExtractionTechniques() != null && result.getExtractionTechniques().length == 1) {
                String str = result.getExtractionTechniques()[0];
                ArrayList arrayList = (ArrayList) hashtable.get(str);
                if (arrayList == null) {
                    arrayList = new ArrayList();
                    hashtable.put(str, arrayList);
                }
                arrayList.add(result);
            }
        }
        for (List list : hashtable.values()) {
            double d = 0.0d;
            float f = 0.0f;
            Iterator it = list.iterator();
            while (it.hasNext()) {
                float score = ((Result) it.next()).getScore();
                d += r0.getNormScore() / score;
                if (score > f) {
                    f = score;
                }
            }
            double size = d / list.size();
            Iterator it2 = list.iterator();
            while (it2.hasNext()) {
                ((Result) it2.next()).setNormScore((float) (r0.getScore() * size));
            }
        }
        return resultArr;
    }

    public Result[] preserveOrderTop(Result[] resultArr) {
        Hashtable hashtable = new Hashtable();
        for (Result result : resultArr) {
            if (result.getScore() > 0.0f && result.getScore() != Float.POSITIVE_INFINITY && result.getExtractionTechniques() != null && result.getExtractionTechniques().length == 1) {
                String str = result.getExtractionTechniques()[0];
                ArrayList arrayList = (ArrayList) hashtable.get(str);
                if (arrayList == null) {
                    arrayList = new ArrayList();
                    hashtable.put(str, arrayList);
                }
                arrayList.add(result);
            }
        }
        for (List<Result> list : hashtable.values()) {
            float f = 0.0f;
            float f2 = 0.0f;
            for (Result result2 : list) {
                float score = result2.getScore();
                float normScore = result2.getNormScore();
                if (score > f) {
                    f = score;
                    f2 = normScore;
                }
            }
            double d = f2 / f;
            Iterator it = list.iterator();
            while (it.hasNext()) {
                ((Result) it.next()).setNormScore((float) (r0.getScore() * d));
            }
        }
        return resultArr;
    }

    @Override // info.ephyra.answerselection.filters.Filter
    public Result[] apply(Result[] resultArr) {
        if (classifier == null) {
            return resultArr;
        }
        for (Result result : resultArr) {
            if (result.getScore() > 0.0f && result.getScore() != Float.POSITIVE_INFINITY && result.getExtractionTechniques() != null && result.getExtractionTechniques().length == 1) {
                result.setNormScore((float) classifier.classification(createInstance(SELECTED_FEATURES, result, resultArr)).posProbability());
            }
        }
        return preserveOrderTop(resultArr);
    }
}
