Skip to content

Commit

Permalink
#8 - added precision, recall, F1 report for CV and TrainTest
Browse files Browse the repository at this point in the history
  • Loading branch information
Horsmann committed Feb 20, 2017
1 parent 819ef04 commit 08fce2e
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.dkpro.tc.ml.report.BatchCrossValidationReport;

import de.unidue.ltl.flextag.core.reports.CvAvgAccuracyReport;
import de.unidue.ltl.flextag.core.reports.CvAvgPerWordClassReport;
import de.unidue.ltl.flextag.core.reports.CvAvgPosTagPrecisionRecallF1;

public class FlexTagCrossValidation
extends FlexTagSetUp
Expand All @@ -55,7 +55,7 @@ private List<Class<? extends Report>> initCrossValidationReports()
List<Class<? extends Report>> r = new ArrayList<>();
r.add(BatchCrossValidationReport.class);
r.add(CvAvgAccuracyReport.class);
r.add(CvAvgPerWordClassReport.class);
r.add(CvAvgPosTagPrecisionRecallF1.class);
return r;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.dkpro.tc.ml.ExperimentTrainTest;
import org.dkpro.tc.ml.report.BatchTrainTestReport;

import de.unidue.ltl.flextag.core.reports.TtAccuracyPerWordClassReport;
import de.unidue.ltl.flextag.core.reports.TtPoSTagPrecisionRecallF1;
import de.unidue.ltl.flextag.core.reports.TtAccuracyReport;

public class FlexTagTrainTest
Expand All @@ -56,7 +56,7 @@ private List<Class<? extends Report>> initTrainTestReports()
List<Class<? extends Report>> r = new ArrayList<>();
r.add(BatchTrainTestReport.class);
r.add(TtAccuracyReport.class);
r.add(TtAccuracyPerWordClassReport.class);
r.add(TtPoSTagPrecisionRecallF1.class);
return r;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

Expand All @@ -35,7 +36,7 @@
/**
* Determines the accuracy for each word class
*/
public class CvAvgPerWordClassReport
public class CvAvgPosTagPrecisionRecallF1
extends BatchReportBase
implements Constants
{
Expand Down Expand Up @@ -75,25 +76,34 @@ public void execute()
}

StringBuilder sb = new StringBuilder();
sb.append(String.format("%20s\t%8s\t%5s%n", "PoS", "Occr.", "Acc"));
sb.append(String.format("#%10s\t%5s\t%5s\t%5s\t%5s%n", "Class", "Occr", "Prec.", "Reca.",
"F1"));

List<String> keySet = new ArrayList<>(map.keySet());
Collections.sort(keySet);
for (String k : keySet) {
List<WordClass> list = map.get(k);

Double N = new Double(0);
double acc = 0;
long N = 0;
double precision = 0;
double recall = 0;
double f1 = 0;

for (WordClass wc : list) {
N += wc.getN();
acc += (wc.getCorrect() / wc.getN());
N += wc.frequency;
precision = wc.precision;
recall = wc.recall;
f1 = wc.f1;
}
N /= list.size();
acc /= list.size();

sb.append(String.format("%20s\t%8d\t%5s\n", k, N.intValue(),
String.format("%3.1f", acc * 100)));
precision /= list.size();
recall /= list.size();
f1 /= list.size();

sb.append(String.format("%10s", k) + "\t" + String.format("%5d", N) + "\t"
+ String.format("%5.2f", precision) + "\t" + String.format("%5.2f", recall)
+ "\t" + String.format("%5.2f", f1) + "\n");

}

File locateKey = storageService.locateKey(subcontext.getId(), OUTPUT_FILE);
Expand Down Expand Up @@ -135,44 +145,73 @@ private List<String> getFoldersOfSingleRuns(File attributesTXT)
}

private Map<String, WordClass> getWcPerformances(File locateKey)
throws IOException
{
Map<String, WordClass> wcp = new HashMap<>();
throws IOException
{
Map<String, WordClass> wcp = new HashMap<>();

List<String> lines = FileUtils.readLines(locateKey);
Map<String, String> labels = getLabels(lines);
List<String> lines = FileUtils.readLines(locateKey);
Map<String, String> labels = getLabels(lines);

for (String l : lines) {
if (l.startsWith("#")) {
continue;
}
String[] entry = splitAtFirstEqualSignRightHandSide(l);
List<String> predictions = new ArrayList<>();
List<String> gold = new ArrayList<>();

String pg = entry[1];
String[] split = pg.split(";");
for (String l : lines) {
if (l.startsWith("#")) {
continue;
}
String[] entry = splitAtFirstEqualSignRightHandSide(l);

if (split.length < 2) {
System.out.println("ERROR\t" + l);
continue;
}
String pg = entry[1];
String[] split = pg.split(";");

if (split.length < 2) {
System.out.println("ERROR\t" + l);
continue;
}

String prediction = labels.get(split[0]);
String gold = labels.get(split[1]);
String p = labels.get(split[0]);
String g = labels.get(split[1]);

WordClass wordClass = wcp.get(gold);
if (wordClass == null) {
wordClass = new WordClass();
}
if (gold.equals(prediction)) {
wordClass.incrementCorrect();
predictions.add(p);
gold.add(g);
}
else {
wordClass.incrementIncorrect();

List<String> allGoldTags = new ArrayList<>(new HashSet<>(gold));
Collections.sort(allGoldTags);

for (String t : allGoldTags) {
double tp = 0, fp = 0, tn = 0, fn = 0;
long freq = 0;
for (int i = 0; i < gold.size(); i++) {
String g = gold.get(i);
String p = predictions.get(i);

if (!g.equals(t) && !p.equals(t)) {
tn++;
}
else if (!g.equals(t) && p.equals(t)) {
fp++;
}
else if (g.equals(t) && !p.equals(t)) {
fn++;
}
else if (g.equals(t) && p.equals(t)) {
tp++;
}

if (t.equals(g)) {
freq++;
}
}

double recall = tp / (tp + fp);
double precision = tp / (tp + fn);
double f1 = (2 * (precision * recall)) / (precision + recall);

wcp.put(t, new WordClass(precision, recall, f1, freq));
}
wcp.put(gold, wordClass);
return wcp;
}
return wcp;
}

private String[] splitAtFirstEqualSignRightHandSide(String l)
{
Expand Down Expand Up @@ -210,30 +249,5 @@ private Map<String, String> extractIdLabelMap(String s)
return id2label;
}

class WordClass
{
double correct = 0;
double incorrect = 0;

public Double getN()
{
return correct + incorrect;
}

public Double getCorrect()
{
return correct;
}

public void incrementCorrect()
{
correct++;
}

public void incrementIncorrect()
{
incorrect++;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

Expand All @@ -35,7 +36,7 @@
/**
* Determines the accuracy for each word class
*/
public class TtAccuracyPerWordClassReport
public class TtPoSTagPrecisionRecallF1
extends BatchReportBase
implements Constants
{
Expand Down Expand Up @@ -67,13 +68,14 @@ private String generateWordClassReport(File locateKey)
List<String> keySet = new ArrayList<String>(wcp.keySet());
Collections.sort(keySet);

sb.append(String.format("#%10s\t%5s\t%5s%n", "Class", "Occr", "Accr"));
sb.append(String.format("#%10s\t%5s\t%5s\t%5s\t%5s%n", "Class", "Occr", "Prec.", "Reca.",
"F1"));
for (String k : keySet) {
WordClass wc = wcp.get(k);
double accuracy = wc.getCorrect() / wc.getN() * 100;

sb.append(String.format("%10s", k) + "\t" + String.format("%5d", wc.getN().intValue())
+ "\t" + String.format("%5.1f%n", accuracy));
sb.append(String.format("%10s", k) + "\t" + String.format("%5d", wc.frequency) + "\t"
+ String.format("%5.2f", wc.precision) + "\t" + String.format("%5.2f", wc.recall)
+ "\t" + String.format("%5.2f", wc.f1) + "\n");
}

return sb.toString();
Expand All @@ -88,6 +90,9 @@ private Map<String, WordClass> getWcPerformances(File locateKey)
List<String> lines = FileUtils.readLines(locateKey);
Map<String, String> labels = getLabels(lines);

List<String> predictions = new ArrayList<>();
List<String> gold = new ArrayList<>();

for (String l : lines) {
if (l.startsWith("#")) {
continue;
Expand All @@ -102,20 +107,46 @@ private Map<String, WordClass> getWcPerformances(File locateKey)
continue;
}

String prediction = labels.get(split[0]);
String gold = labels.get(split[1]);
String p = labels.get(split[0]);
String g = labels.get(split[1]);

WordClass wordClass = wcp.get(gold);
if (wordClass == null) {
wordClass = new WordClass();
}
if (gold.equals(prediction)) {
wordClass.incrementCorrect();
}
else {
wordClass.incrementIncorrect();
predictions.add(p);
gold.add(g);
}

List<String> allGoldTags = new ArrayList<>(new HashSet<>(gold));
Collections.sort(allGoldTags);

for (String t : allGoldTags) {
double tp = 0, fp = 0, tn = 0, fn = 0;
long freq = 0;
for (int i = 0; i < gold.size(); i++) {
String g = gold.get(i);
String p = predictions.get(i);

if (!g.equals(t) && !p.equals(t)) {
tn++;
}
else if (!g.equals(t) && p.equals(t)) {
fp++;
}
else if (g.equals(t) && !p.equals(t)) {
fn++;
}
else if (g.equals(t) && p.equals(t)) {
tp++;
}

if (t.equals(g)) {
freq++;
}
}
wcp.put(gold, wordClass);

double recall = tp / (tp + fp);
double precision = tp / (tp + fn);
double f1 = (2 * (precision * recall)) / (precision + recall);

wcp.put(t, new WordClass(precision, recall, f1, freq));
}
return wcp;
}
Expand Down Expand Up @@ -156,30 +187,5 @@ private Map<String, String> extractIdLabelMap(String s)
return id2label;
}

class WordClass
{
double correct = 0;
double incorrect = 0;

public Double getN()
{
return correct + incorrect;
}

public Double getCorrect()
{
return correct;
}

public void incrementCorrect()
{
correct++;
}

public void incrementIncorrect()
{
incorrect++;
}
}

}
Loading

0 comments on commit 08fce2e

Please sign in to comment.