/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.metrics;

import edu.stanford.nlp.international.Languages;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.parser.lexparser.EnglishTreebankParserParams;
import edu.stanford.nlp.parser.lexparser.Lexicon;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.metrics.AbstractEval;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.DiskTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Evaluates POS tagging by comparing the tagged labeled yield of a guess tree
 * with that of a gold tree, using the precision/recall machinery of
 * {@link AbstractEval}.
 *
 * <p>When category-level evaluation is enabled ({@code -c} on the command line)
 * the evaluator also accumulates per-tag precision, recall and F1. If a
 * {@link Lexicon} was supplied, it also records, for each gold tag, the fraction
 * of tagging errors that fall on words the lexicon does not know.
 */
public class TaggingEval
extends AbstractEval {
    private final Lexicon lex;
    /** Set from the command line ({@code -c}); read when an evaluator is constructed. */
    private static boolean doCatLevelEval = false;
    /**
     * Snapshot of {@link #doCatLevelEval} taken in the constructor. The per-category
     * counters below are allocated exactly when this is true, so later changes to the
     * static flag cannot cause them to be used while null.
     */
    private final boolean catLevelEval;
    private Counter<String> precisions;
    private Counter<String> recalls;
    private Counter<String> f1s;
    private Counter<String> precisions2;
    private Counter<String> recalls2;
    private Counter<String> pnums2;
    private Counter<String> rnums2;
    /** Per gold tag: tagging errors on words unknown to {@link #lex}. */
    private Counter<String> percentOOV;
    /** Per gold tag: all tagging errors (denominator for the OOV rate). */
    private Counter<String> percentOOV2;
    private static final int minArgs = 2;
    private static final StringBuilder usage = new StringBuilder();

    public TaggingEval(String str) {
        this(str, true, null);
    }

    /**
     * @param str label used in printed output
     * @param runningAverages whether {@link #evaluate} prints running averages
     * @param lex lexicon used for OOV statistics; may be {@code null} to disable them
     */
    public TaggingEval(String str, boolean runningAverages, Lexicon lex) {
        super(str, runningAverages);
        this.lex = lex;
        this.catLevelEval = doCatLevelEval;
        if (this.catLevelEval) {
            this.precisions = new ClassicCounter<String>();
            this.recalls = new ClassicCounter<String>();
            this.f1s = new ClassicCounter<String>();
            this.precisions2 = new ClassicCounter<String>();
            this.recalls2 = new ClassicCounter<String>();
            this.pnums2 = new ClassicCounter<String>();
            this.rnums2 = new ClassicCounter<String>();
            this.percentOOV = new ClassicCounter<String>();
            this.percentOOV2 = new ClassicCounter<String>();
        }
    }

    /** Returns the tagged labeled yield of {@code tree} as a set (empty for a null tree). */
    @Override
    protected Set<HasTag> makeObjects(Tree tree) {
        Set<HasTag> objects = new HashSet<HasTag>();
        if (tree != null) {
            objects.addAll(tree.taggedLabeledYield());
        }
        return objects;
    }

    /** Groups the tagged labeled yield of {@code t} by label value (the tag). */
    private Map<String, Set<Label>> makeObjectsByCat(Tree t) {
        Map<String, Set<Label>> catMap = new HashMap<String, Set<Label>>();
        for (CoreLabel label : t.taggedLabeledYield()) {
            Set<Label> catSet = catMap.get(label.value());
            if (catSet == null) {
                catSet = new HashSet<Label>();
                catMap.put(label.value(), catSet);
            }
            catSet.add(label);
        }
        return catMap;
    }

    @Override
    public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
        if (gold == null || guess == null) {
            System.err.printf("%s: Cannot compare against a null gold or guess tree!\n", this.getClass().getName());
            return;
        }
        super.evaluate(guess, gold, pw);
        if (!this.catLevelEval) {
            return;
        }
        Map<String, Set<Label>> guessCats = this.makeObjectsByCat(guess);
        Map<String, Set<Label>> goldCats = this.makeObjectsByCat(gold);
        Set<String> allCats = new HashSet<String>();
        allCats.addAll(guessCats.keySet());
        allCats.addAll(goldCats.keySet());
        boolean printAverages = pw != null && this.runningAverages;
        for (String cat : allCats) {
            Set<Label> thisGuessCats = guessCats.get(cat);
            Set<Label> thisGoldCats = goldCats.get(cat);
            if (thisGuessCats == null) {
                thisGuessCats = new HashSet<Label>();
            }
            if (thisGoldCats == null) {
                thisGoldCats = new HashSet<Label>();
            }
            double currentPrecision = precision(thisGuessCats, thisGoldCats);
            // Recall is precision with the roles of guess and gold swapped.
            double currentRecall = precision(thisGoldCats, thisGuessCats);
            double currentF1 = currentPrecision > 0.0 && currentRecall > 0.0 ? 2.0 / (1.0 / currentPrecision + 1.0 / currentRecall) : 0.0;
            this.precisions.incrementCount(cat, currentPrecision);
            this.recalls.incrementCount(cat, currentRecall);
            this.f1s.incrementCount(cat, currentF1);
            // Size-weighted sums give corpus-level ("evalb") scores when divided by the size totals.
            this.precisions2.incrementCount(cat, (double) thisGuessCats.size() * currentPrecision);
            this.pnums2.incrementCount(cat, thisGuessCats.size());
            this.recalls2.incrementCount(cat, (double) thisGoldCats.size() * currentRecall);
            this.rnums2.incrementCount(cat, thisGoldCats.size());
            if (printAverages) {
                this.printRunningAverages(pw, cat, currentPrecision, currentRecall, currentF1);
            }
        }
        // OOV statistics are per token, not per category, so measure once per sentence.
        if (this.lex != null) {
            this.measureOOV(guess, gold);
        }
        if (printAverages) {
            pw.println("========================================");
        }
    }

    /** Prints the current and running P/R/F1 for one category. */
    private void printRunningAverages(PrintWriter pw, String cat, double currentPrecision, double currentRecall, double currentF1) {
        pw.println(cat + "\tP: " + truncatePercent(currentPrecision * 10000.0) + " (sent ave " + truncatePercent(this.precisions.getCount(cat) * 10000.0 / this.num) + ") (evalb " + truncatePercent(this.precisions2.getCount(cat) * 10000.0 / this.pnums2.getCount(cat)) + ")");
        pw.println("\tR: " + truncatePercent(currentRecall * 10000.0) + " (sent ave " + truncatePercent(this.recalls.getCount(cat) * 10000.0 / this.num) + ") (evalb " + truncatePercent(this.recalls2.getCount(cat) * 10000.0 / this.rnums2.getCount(cat)) + ")");
        double cF1 = 2.0 / (this.rnums2.getCount(cat) / this.recalls2.getCount(cat) + this.pnums2.getCount(cat) / this.precisions2.getCount(cat));
        pw.println(this.str + " F1: " + truncatePercent(currentF1 * 10000.0) + " (sent ave " + truncatePercent(10000.0 * this.f1s.getCount(cat) / this.num) + ", evalb " + truncatePercent(10000.0 * cF1) + ")");
    }

    /** Truncates a value already scaled by 10000 toward zero, then divides by 100 (a percentage with two truncated decimals). */
    private static double truncatePercent(double scaled) {
        return (double) ((int) scaled) / 100.0;
    }

    /**
     * For every token whose guessed tag differs from the gold tag, counts the error
     * under the gold tag. If the word is unknown to {@link #lex}, it also counts it
     * as an OOV error.
     */
    private void measureOOV(Tree guess, Tree gold) {
        List<CoreLabel> goldTagging = gold.taggedLabeledYield();
        List<CoreLabel> guessTagging = guess.taggedLabeledYield();
        assert (goldTagging.size() == guessTagging.size());
        // Bound by the shorter yield so a mismatch cannot throw when assertions are off.
        int length = Math.min(goldTagging.size(), guessTagging.size());
        for (int i = 0; i < length; ++i) {
            CoreLabel goldLabel = goldTagging.get(i);
            String goldTag = goldLabel.tag();
            String guessTag = guessTagging.get(i).tag();
            // The labels come from different trees, so compare tag strings, not object identity.
            boolean sameTag = goldTag == null ? guessTag == null : goldTag.equals(guessTag);
            if (sameTag) {
                continue;
            }
            this.percentOOV2.incrementCount(goldTag);
            if (!this.lex.isKnown(goldLabel.word())) {
                this.percentOOV.incrementCount(goldTag);
            }
        }
    }

    @Override
    public void display(boolean verbose, PrintWriter pw) {
        super.display(verbose, pw);
        if (!this.catLevelEval) {
            return;
        }
        Set<String> cats = new HashSet<String>();
        cats.addAll(this.precisions.keySet());
        cats.addAll(this.recalls.keySet());
        // Rank categories by corpus-level F1 (undefined F1 ranks as -1), ties broken by name.
        final Map<String, Double> f1ByCat = new HashMap<String, Double>();
        for (String cat : cats) {
            double prec = this.precisions2.getCount(cat) / this.pnums2.getCount(cat);
            double rec = this.recalls2.getCount(cat) / this.rnums2.getCount(cat);
            double f1 = 2.0 / (1.0 / prec + 1.0 / rec);
            f1ByCat.put(cat, Double.isNaN(f1) ? -1.0 : f1);
        }
        List<String> sortedCats = new ArrayList<String>(cats);
        Collections.sort(sortedCats, new Comparator<String>() {
            @Override
            public int compare(String a, String b) {
                int byF1 = Double.compare(f1ByCat.get(a), f1ByCat.get(b));
                return byF1 != 0 ? byF1 : a.compareTo(b);
            }
        });
        DecimalFormat nf = new DecimalFormat("0.00");
        pw.println("============================================================");
        pw.println("Tagging Performance by Category -- final statistics");
        pw.println("============================================================");
        for (String cat : sortedCats) {
            double pnum2 = this.pnums2.getCount(cat);
            double rnum2 = this.rnums2.getCount(cat);
            double prec = this.precisions2.getCount(cat) / pnum2 * 100.0;
            double rec = this.recalls2.getCount(cat) / rnum2 * 100.0;
            double f1 = 2.0 / (1.0 / prec + 1.0 / rec);
            String oov;
            if (this.lex == null) {
                oov = " N/A";
            } else {
                double errors = this.percentOOV2.getCount(cat);
                // No tagging errors for this tag: the OOV share is undefined.
                oov = errors == 0.0 ? " N/A" : nf.format(this.percentOOV.getCount(cat) / errors);
            }
            pw.println(cat + "\tLP: " + (pnum2 == 0.0 ? " N/A" : nf.format(prec)) + "\tguessed: " + (int) pnum2 + "\tLR: " + (rnum2 == 0.0 ? " N/A" : nf.format(rec)) + "\tgold:  " + (int) rnum2 + "\tF1: " + (pnum2 == 0.0 || rnum2 == 0.0 ? " N/A" : nf.format(f1)) + "\tOOV: " + oov);
        }
        pw.println("============================================================");
    }

    /** Prints the usage message and terminates the JVM with status -1. */
    private static void exitWithUsage() {
        System.out.println(usage.toString());
        System.exit(-1);
    }

    /** Returns the trimmed value of the option at {@code idx}, or exits with usage if it is missing. */
    private static String requireOptionValue(String[] args, int idx) {
        if (idx >= args.length) {
            exitWithUsage();
        }
        return args[idx].trim();
    }

    public static void main(String[] args) {
        if (args.length < minArgs) {
            exitWithUsage();
        }
        TreebankLangParserParams tlpp = new EnglishTreebankParserParams();
        int maxGoldYield = Integer.MAX_VALUE;
        int maxGuessYield = Integer.MAX_VALUE;
        boolean VERBOSE = false;
        boolean skipGuess = false;
        String guessFile = null;
        String goldFile = null;
        for (int i = 0; i < args.length; ++i) {
            String arg = args[i];
            if (!arg.startsWith("-")) {
                // The first non-option argument is the gold file; the next one is the guess file.
                if (i + 1 >= args.length) {
                    exitWithUsage();
                }
                goldFile = arg;
                guessFile = args[i + 1];
                break;
            }
            if (arg.equals("-l")) {
                Languages.Language lang = Languages.Language.valueOf(requireOptionValue(args, ++i));
                tlpp = Languages.getLanguageParams(lang);
            } else if (arg.equals("-y")) {
                maxGoldYield = Integer.parseInt(requireOptionValue(args, ++i));
            } else if (arg.equals("-v")) {
                VERBOSE = true;
            } else if (arg.equals("-c")) {
                doCatLevelEval = true;
            } else if (arg.equals("-g")) {
                maxGuessYield = Integer.parseInt(requireOptionValue(args, ++i));
                skipGuess = true;
            } else {
                exitWithUsage();
            }
        }
        if (goldFile == null) {
            exitWithUsage();
        }
        PrintWriter pwOut = tlpp.pw();
        DiskTreebank guessTreebank = tlpp.diskTreebank();
        guessTreebank.loadPath(guessFile);
        pwOut.println("GUESS TREEBANK:");
        pwOut.println(guessTreebank.textualSummary());
        DiskTreebank goldTreebank = tlpp.diskTreebank();
        goldTreebank.loadPath(goldFile);
        pwOut.println("GOLD TREEBANK:");
        pwOut.println(goldTreebank.textualSummary());
        TaggingEval taggingEval = new TaggingEval("Tagging LP/LR");
        TreeTransformer tc = tlpp.collinizer();
        Iterator<Tree> goldItr = goldTreebank.iterator();
        int goldLineId = 0;
        int skippedGuessTrees = 0;
        guessLoop: for (Tree guess : guessTreebank) {
            Tree evalGuess = tc.transformTree(guess);
            int guessYieldLength = guess.yield().size();
            if (guessYieldLength > maxGuessYield) {
                ++skippedGuessTrees;
                continue;
            }
            boolean doneEval = false;
            // Advance through gold trees (skipping over-long ones) until one is paired with this guess.
            while (goldItr.hasNext() && !doneEval) {
                Tree gold = goldItr.next();
                Tree evalGold = tc.transformTree(gold);
                ++goldLineId;
                int goldYieldLength = gold.yield().size();
                if (goldYieldLength > maxGoldYield) {
                    continue;
                }
                if (goldYieldLength != guessYieldLength) {
                    ++skippedGuessTrees;
                    pwOut.println("Yield mismatch at gold line " + goldLineId);
                    continue guessLoop;
                }
                taggingEval.evaluate(evalGuess, evalGold, VERBOSE ? pwOut : null);
                doneEval = true;
            }
        }
        pwOut.println("================================================================================");
        if (skippedGuessTrees != 0) {
            pwOut.printf("%s %d guess trees\n", skipGuess ? "Skipped" : "Unable to evaluate", skippedGuessTrees);
        }
        taggingEval.display(true, pwOut);
        pwOut.println();
        pwOut.close();
    }

    static {
        usage.append(String.format("Usage: java %s [OPTS] gold guess\n\n", TaggingEval.class.getName()));
        usage.append("Options:\n");
        usage.append("  -v         : Verbose mode.\n");
        usage.append("  -l lang    : Select language settings from " + Languages.listOfLanguages() + "\n");
        usage.append("  -y num     : Skip gold trees with yields longer than num.\n");
        usage.append("  -g num     : Skip guess trees with yields longer than num.\n");
        usage.append("  -c         : Compute LP/LR/F1 by category.\n");
    }
}

