Skip to content

Commit

Permalink
#491 - Id2Outcome report might miss values in unit-mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Horsmann committed May 26, 2018
1 parent 56c31f9 commit e512938
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,20 @@
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.output.FileWriterWithEncoding;
import org.dkpro.lab.storage.StorageService.AccessMode;
import org.dkpro.tc.core.Constants;
import org.dkpro.tc.io.libsvm.AdapterFormat;
import org.dkpro.tc.io.libsvm.LibsvmDataFormatWriter;
import org.dkpro.tc.ml.report.TcBatchReportBase;
import org.dkpro.tc.ml.report.util.SortedKeyProperties;

public class LibsvmDataFormatOutcomeIdReport
extends TcBatchReportBase
Expand Down Expand Up @@ -67,8 +66,8 @@ public void execute() throws Exception
List<String> predictions = readPredictions();
Map<String, String> index2instanceIdMap = getMapping(isUnit || isSequence);

Properties prop = new SortedKeyProperties();
int lineCounter = 0;
StringBuilder sb = new StringBuilder();
for (String line : predictions) {
if (line.startsWith("#")) {
continue;
Expand All @@ -80,27 +79,26 @@ public void execute() throws Exception
String goldString = split[1];

if (isRegression) {
prop.setProperty(key,
predictionString + ";" + goldString + ";" + THRESHOLD_CONSTANT);
sb.append(key + "=" + predictionString + ";" + goldString + ";" + THRESHOLD_CONSTANT
+ "\n");
}
else {
int pred = Double.valueOf(predictionString).intValue();
int gold = Double.valueOf(goldString).intValue();
prop.setProperty(key, pred + ";" + gold + ";" + THRESHOLD_CONSTANT);
sb.append(key + "=" + pred + ";" + gold + ";" + THRESHOLD_CONSTANT
+ "\n");
}
lineCounter++;
}

File targetFile = getTargetOutputFile();

FileWriterWithEncoding fw = null;
try {
fw = new FileWriterWithEncoding(targetFile, "utf-8");
prop.store(fw, header);
}
finally {
IOUtils.closeQuietly(fw);
}

DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
Calendar cal = Calendar.getInstance();
String timeStamp = dateFormat.format(cal.getTime());

String content = header + "\n#" + timeStamp + "\n" + sb.toString();
FileUtils.writeStringToFile(targetFile, content, "utf-8");

}

Expand Down Expand Up @@ -185,7 +183,7 @@ private String buildHeader(Map<Integer, String> id2label, boolean isRegression)
throws UnsupportedEncodingException
{
StringBuilder header = new StringBuilder();
header.append("ID=PREDICTION;GOLDSTANDARD;THRESHOLD" + "\n" + "labels" + " ");
header.append("#ID=PREDICTION;GOLDSTANDARD;THRESHOLD" + "\n#" + "labels" + " ");

if (isRegression) {
// no label mapping for regression so that is all we have to do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,19 @@
import java.io.ObjectInputStream;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.output.FileWriterWithEncoding;
import org.apache.commons.lang.StringUtils;
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.storage.StorageService.AccessMode;
import org.dkpro.tc.ml.report.TcBatchReportBase;
import org.dkpro.tc.ml.report.util.SortedKeyProperties;
import org.dkpro.tc.ml.weka.core._eka;
import org.dkpro.tc.ml.weka.task.WekaOutcomeHarmonizer;
import org.dkpro.tc.ml.weka.task.WekaTestTask;
Expand Down Expand Up @@ -89,25 +88,25 @@ public void execute() throws Exception

List<String> labels = getLabels(isMultiLabel, isRegression);

Properties props;

String content;
if (isMultiLabel) {
MultilabelResult r = readMlResultFromFile(mlResults);
props = generateMlProperties(predictions, labels, r);
content = generateMlProperties(predictions, labels, r);
}
else {
Map<Integer, String> documentIdMap = loadDocumentMap();
props = generateSlProperties(predictions, isRegression, isUnit, documentIdMap, labels);
content = generateSlProperties(predictions, isRegression, isUnit, documentIdMap, labels);
}

FileWriterWithEncoding fw = null;
try {
fw = new FileWriterWithEncoding(getTargetOutputFile(), "utf-8");
props.store(fw, generateHeader(labels));
}
finally {
IOUtils.closeQuietly(fw);
}
String header = generateHeader(labels);

DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
Calendar cal = Calendar.getInstance();
String timeStamp = dateFormat.format(cal.getTime());

String data = header + "\n#" + timeStamp + "\n" + content;

FileUtils.writeStringToFile(getTargetOutputFile(), data, "utf-8");
}

protected File getTargetOutputFile()
Expand All @@ -133,8 +132,8 @@ private List<String> getLabels(boolean multiLabel, boolean regression) throws IO
protected static String generateHeader(List<String> labels) throws UnsupportedEncodingException
{
StringBuilder comment = new StringBuilder();
comment.append("ID=PREDICTION" + SEPARATOR_CHAR + "GOLDSTANDARD" + SEPARATOR_CHAR
+ "THRESHOLD" + "\n" + "labels");
comment.append("#ID=PREDICTION" + SEPARATOR_CHAR + "GOLDSTANDARD" + SEPARATOR_CHAR
+ "THRESHOLD" + "\n#" + "labels");

// add numbered indexing of labels: e.g. 0=NPg, 1=JJ
for (int i = 0; i < labels.size(); i++) {
Expand All @@ -144,12 +143,11 @@ protected static String generateHeader(List<String> labels) throws UnsupportedEn
return comment.toString();
}

protected static Properties generateMlProperties(Instances predictions, List<String> labels,
protected static String generateMlProperties(Instances predictions, List<String> labels,
MultilabelResult r)
throws ClassNotFoundException, IOException
{
Properties props = new SortedKeyProperties();

StringBuilder sb = new StringBuilder();
int attOffset = predictions.attribute(ID_FEATURE_NAME).index();

Map<String, Integer> class2number = classNamesToMapping(labels);
Expand All @@ -168,17 +166,16 @@ protected static Properties generateMlProperties(Instances predictions, List<Str
String s = (StringUtils.join(predList, ",") + SEPARATOR_CHAR
+ StringUtils.join(goldList, ",") + SEPARATOR_CHAR + bipartition);
String stringValue = predictions.get(i).stringValue(attOffset);
props.setProperty(stringValue, s);
sb.append(stringValue + "=" + s+"\n");
}
return props;
return sb.toString();
}

protected Properties generateSlProperties(Instances predictions, boolean isRegression,
protected String generateSlProperties(Instances predictions, boolean isRegression,
boolean isUnit, Map<Integer, String> documentIdMap, List<String> labels)
throws Exception
{

Properties props = new SortedKeyProperties();
String[] classValues = new String[predictions.numClasses()];

for (int i = 0; i < predictions.numClasses(); i++) {
Expand All @@ -189,6 +186,8 @@ protected Properties generateSlProperties(Instances predictions, boolean isRegre

prepareBaseline();

StringBuilder sb = new StringBuilder();

int idx = 0;
for (Instance inst : predictions) {
Double gold;
Expand All @@ -208,24 +207,22 @@ protected Properties generateSlProperties(Instances predictions, boolean isRegre
// .get(gsAtt.value(prediction.intValue()));
Integer goldAsNumber = class2number.get(classValues[gold.intValue()]);

String stringValue = inst.stringValue(attOffset);
String key = inst.stringValue(attOffset);
if (!isUnit && documentIdMap != null) {
stringValue = documentIdMap.get(idx++);
key = documentIdMap.get(idx++);
}
props.setProperty(stringValue, getPrediction(prediction, class2number, gsAtt)
+ SEPARATOR_CHAR + goldAsNumber + SEPARATOR_CHAR + String.valueOf(-1));
sb.append(key + "=" + getPrediction(prediction, class2number, gsAtt) + SEPARATOR_CHAR + goldAsNumber + SEPARATOR_CHAR + String.valueOf(-1) + "\n");
}
else {
// the outcome is numeric
String stringValue = inst.stringValue(attOffset);
String key = inst.stringValue(attOffset);
if (documentIdMap != null) {
stringValue = documentIdMap.get(idx++);
key = documentIdMap.get(idx++);
}
props.setProperty(stringValue,
prediction + SEPARATOR_CHAR + gold + SEPARATOR_CHAR + String.valueOf(0));
sb.append(key + "=" + prediction + SEPARATOR_CHAR + gold + SEPARATOR_CHAR + "-1" + "\n");
}
}
return props;
return sb.toString();
}

protected String getPrediction(Double prediction, Map<String, Integer> class2number,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,25 @@
package org.dkpro.tc.ml.report.deeplearning;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.dkpro.lab.storage.StorageService.AccessMode;
import org.dkpro.tc.core.Constants;
import org.dkpro.tc.core.DeepLearningConstants;
import org.dkpro.tc.core.ml.TcDeepLearningAdapter;
import org.dkpro.tc.ml.report.TcBatchReportBase;
import org.dkpro.tc.ml.report.util.SortedKeyProperties;

public class DeepLearningId2OutcomeReport
extends TcBatchReportBase
Expand Down Expand Up @@ -76,7 +74,7 @@ public void execute() throws Exception
Map<String, String> inverseMap = inverseMap(map);

StringBuilder header = new StringBuilder();
header.append("ID=PREDICTION;GOLDSTANDARD;THRESHOLD\nlabels ");
header.append("#ID=PREDICTION;GOLDSTANDARD;THRESHOLD\n#labels ");

List<String> k = new ArrayList<>(map.keySet());
for (Integer i = 0; i < map.keySet().size(); i++) {
Expand All @@ -94,8 +92,8 @@ public void execute() throws Exception
}

List<String> nameOfTargets = getNameOfTargets();
Properties prop = new SortedKeyProperties();

StringBuilder sb = new StringBuilder();
int shift = 0;
for (int i = 0; i < predictions.size(); i++) {

Expand All @@ -115,7 +113,7 @@ public void execute() throws Exception
String[] split = p.split("\t");

if (isMultiLabel) {
multilabelReport(id, split, isIntegerMode, prop, map);
sb = multilabelReport(id, split, isIntegerMode, sb, map);
continue;
}

Expand All @@ -136,19 +134,16 @@ public void execute() throws Exception
prediction = map.get(split[1]).toString();
}
}
prop.setProperty("" + id,
prediction + SEPARATOR_CHAR + gold + SEPARATOR_CHAR + THRESHOLD);
sb.append(id + "=" + prediction + SEPARATOR_CHAR + gold + SEPARATOR_CHAR + THRESHOLD + "\n");
}

DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
Calendar cal = Calendar.getInstance();
String timeStamp = dateFormat.format(cal.getTime());

String data = header.toString() + "\n#" + timeStamp + "\n" + sb.toString();
File id2o = getTargetFile();
OutputStreamWriter osw = null;
try {
osw = new OutputStreamWriter(new FileOutputStream(id2o), "utf-8");
prop.store(osw, header.toString());
}
finally {
IOUtils.closeQuietly(osw);
}
FileUtils.writeStringToFile(id2o, data, "utf-8");
}

private String determineId(List<String> nameOfTargets, int i, int shift)
Expand Down Expand Up @@ -195,7 +190,7 @@ private Map<String, String> inverseMap(Map<String, String> map)
return inverseMap;
}

private void multilabelReport(String id, String[] split, boolean isIntegerMode, Properties prop,
private StringBuilder multilabelReport(String id, String[] split, boolean isIntegerMode, StringBuilder sb,
Map<String, String> map)
{

Expand All @@ -215,7 +210,8 @@ private void multilabelReport(String id, String[] split, boolean isIntegerMode,
s = split[1].split(" ");
prediction = label2String(s, map);
}
prop.setProperty("" + id, prediction + SEPARATOR_CHAR + gold + SEPARATOR_CHAR + THRESHOLD);
sb.append(id + "=" + prediction + SEPARATOR_CHAR + gold + SEPARATOR_CHAR + THRESHOLD + "\n");
return sb;
}

private String label2String(String[] val, Map<String, String> map)
Expand Down

0 comments on commit e512938

Please sign in to comment.