Inference and minor improvements #16

Open: wants to merge 1 commit into master
13 changes: 13 additions & 0 deletions README.md
@@ -62,6 +62,19 @@ To evaluate a saved model on the test set:
 ```
 ./bin/eval-cnn.sh conf/conll/dilated-cnn.conf test --load_model path/to/model
 ```
 
+Inference
+----
+To save the predictions on the validation set:
+```
+./bin/predict-cnn.sh conf/conll/dilated-cnn.conf --load_model path/to/model
+```
+To save the predictions on the test set:
+```
+./bin/predict-cnn.sh conf/conll/dilated-cnn.conf test --load_model path/to/model
+```
+
+The predictions are written to `predictions.txt` in the working directory; they currently contain `<OOV>` (out-of-vocabulary) tokens for words that are not present in the training vocabulary.
+
 
 Configs
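For reference, `predictions.txt` contains one `token label` pair per line, with a blank line between sentences. A sample of the expected output (the tokens and labels here are illustrative; the exact label strings depend on the label encoding in your config):

```
John B-PER
lives O
in O
London B-LOC
```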
28 changes: 28 additions & 0 deletions bin/predict-cnn.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+conf=$1
+if [ ! -e "$conf" ]; then
+  echo "Config file '$conf' not found; exiting."
+  exit 1
+fi
+source "$conf"
+
+additional_args=${@:2}
+
+if [[ "$2" == "test" ]]; then
+  dev_dir=$test_dir
+  additional_args=${@:3}
+fi
+
+# escape glob stars in the dev path so they survive the eval below
+dev_fixed=`echo "$dev_dir" | sed 's/\*/\\\*/g'`
+
+cmd="$DILATED_CNN_NER_ROOT/bin/train-cnn.sh \
+  $conf \
+  --predict_only \
+  --load_dir $model_dir \
+  --dev_dir $dev_fixed \
+  $additional_args"
+
+echo ${cmd}
+eval ${cmd}
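For a config that defines `model_dir` and `dev_dir` (the paths below are purely illustrative), the wrapper above expands to a call along these lines:

```
bin/train-cnn.sh conf/conll/dilated-cnn.conf --predict_only --load_dir models/dilated-cnn --dev_dir data/conll2003/eng.testa --load_model path/to/model
```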
38 changes: 36 additions & 2 deletions src/eval_f1.py
@@ -3,6 +3,7 @@
 import time
 import numpy as np
 import sys
+from tqdm import tqdm
 
 
 def is_start(curr):
@@ -37,7 +38,7 @@ def print_context(width, start, tok_list, pred_list, gold_list):
     correct_counts = {t: 0 for t in label_map.values()}
     token_count = 0
     # iterate over batches
-    for predictions, (dev_label_batch, dev_token_batch, _, _, dev_seq_len_batch, _, _) in zip(predictions, batches):
+    for predictions, (dev_label_batch, dev_token_batch, _, _, dev_seq_len_batch, _, _) in tqdm(zip(predictions, batches)):
         # iterate over examples in batch
         for preds, labels, tokens, seq_lens in zip(predictions, dev_label_batch, dev_token_batch, dev_seq_len_batch):
             start = pad_width
@@ -105,7 +106,7 @@ def print_context(width, start, tok_list, pred_list, gold_list):
 
     accuracy = all_correct / all_gold
 
-    print("\t%10s\tPrec\tRecall" % ("F1"))
+    print("\t%10s\tPrec\tRecall" % "F1")
     print("%10s\t%2.2f\t%2.2f\t%2.2f" % ("Micro (Seg)", f1_micro * 100, precision_micro * 100, recall_micro * 100))
     print("%10s\t%2.2f\t%2.2f\t%2.2f" % ("Macro (Seg)", f1_macro * 100, precision_macro * 100, recall_macro * 100))
     print("-------")
@@ -118,6 +119,39 @@ def print_context(width, start, tok_list, pred_list, gold_list):
     return f1_micro, precision
 
 
+def segment_inference(batches, predictions, label_map, type_int_int_map, labels_id_str_map, vocab_id_str_map, outside_idx, pad_width, start_end, extra_text="", verbose=False):
+    # label_map, type_int_int_map, outside_idx and verbose are unused here;
+    # they are kept so the signature mirrors segment_eval above.
+    if extra_text != "":
+        print(extra_text)
+
+    # open the output file once (truncating any previous run) instead of
+    # reopening it in append mode for every segment
+    with open("predictions.txt", 'w') as f:
+        # iterate over batches
+        for batch_predictions, (dev_label_batch, dev_token_batch, _, _, dev_seq_len_batch, _, _) in tqdm(zip(predictions, batches)):
+            # iterate over examples in batch
+            for preds, labels, tokens, seq_lens in zip(batch_predictions, dev_label_batch, dev_token_batch, dev_seq_len_batch):
+                start = pad_width
+                for seq_len in seq_lens:
+                    predicted = preds[start:seq_len+start]
+                    toks = tokens[start:seq_len+start]
+                    # write one "token label" pair per line, with a blank line
+                    # between segments
+                    for word_id, label_id in zip(toks, predicted):
+                        # TODO: Deal with <OOV>
+                        f.write("%s %s\n" % (vocab_id_str_map[word_id], labels_id_str_map[label_id]))
+                    f.write("\n")
+                    start += seq_len + (2 if start_end else 1) * pad_width
+
+
 def print_training_error(num_examples, start_time, epoch_losses, step):
     losses_str = ' '.join(["%5.5f"]*len(epoch_losses)) % tuple(map(lambda l: l/step, epoch_losses))
     print("%20d examples at %5.2f examples/sec. Error: %s" %
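A minimal sketch of how `segment_inference` can be exercised in isolation, run from `src/` with `numpy` and `tqdm` installed; the batch layout and id maps below are assumptions inferred from the tuple unpacking in the function above:

```python
# Hypothetical smoke test for segment_inference; all values are illustrative.
from eval_f1 import segment_inference

vocab_id_str_map = {1: "John", 2: "lives", 3: "here"}
labels_id_str_map = {0: "O", 1: "B-PER"}

predictions = [[[1, 0, 0]]]              # one batch, one example, three predicted label ids
batches = [([[0, 0, 0]],                 # gold label ids (unused when writing predictions)
            [[1, 2, 3]],                 # token ids
            None, None,
            [[3]],                       # sequence lengths: one segment of three tokens
            None, None)]

segment_inference(batches, predictions, None, None,
                  labels_id_str_map, vocab_id_str_map,
                  outside_idx=None, pad_width=0, start_end=False)
# predictions.txt now contains:
# John B-PER
# lives O
# here O
```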
36 changes: 24 additions & 12 deletions src/train.py
@@ -188,11 +188,11 @@ def main(argv):
     for variable in tf.trainable_variables():
         # shape is an array of tf.Dimension
         shape = variable.get_shape()
-        variable_parametes = 1
+        variable_parameters = 1
         for dim in shape:
-            variable_parametes *= dim.value
-        total_parameters += variable_parametes
-        print("Total trainable parameters: %d" % (total_parameters))
+            variable_parameters *= dim.value
+        total_parameters += variable_parameters
+        print("Total trainable parameters: %d" % total_parameters)
 
     if FLAGS.clip_norm > 0:
         grads, _ = tf.clip_by_global_norm(tf.gradients(model.loss, model_vars), FLAGS.clip_norm)
@@ -276,13 +276,21 @@ def run_evaluation(eval_batches, extra_text=""):
         evaluation.print_conlleval_format(FLAGS.print_preds, eval_batches, predictions, labels_id_str_map, vocab_id_str_map, pad_width)
 
         # print evaluation
-        f1_micro, precision = evaluation.segment_eval(eval_batches, predictions, type_set, type_int_int_map,
+        if FLAGS.predict_only:
+            evaluation.segment_inference(eval_batches, predictions, type_set, type_int_int_map,
+                                         labels_id_str_map, vocab_id_str_map,
+                                         outside_idx=map(lambda t: type_set[t] if t in type_set else type_set["O"], outside_set),
+                                         pad_width=pad_width, start_end=FLAGS.start_end,
+                                         extra_text="Segment inference %s:" % extra_text)
+            return None, None
+
+        f1_micro, precision = evaluation.segment_eval(eval_batches, predictions, type_set, type_int_int_map,
                                                       labels_id_str_map, vocab_id_str_map,
                                                       outside_idx=map(lambda t: type_set[t] if t in type_set else type_set["O"], outside_set),
                                                       pad_width=pad_width, start_end=FLAGS.start_end,
                                                       extra_text="Segment evaluation %s:" % extra_text)
         return f1_micro, precision
 
     threads = tf.train.start_queue_runners(sess=sess)
     log_every = int(max(100, num_train_examples / 5))
@@ -463,7 +471,7 @@ def train(max_epochs, best_score, model_hidden_drop, model_input_drop, until_con
                 train_batcher._step += 1
         return best_score, training_iteration, speed_num/speed_denom
 
-    if FLAGS.evaluate_only:
+    if FLAGS.evaluate_only or FLAGS.predict_only:
         if FLAGS.train_eval:
             run_evaluation(train_batches, "(train)")
         print()
@@ -485,13 +493,16 @@ def train(max_epochs, best_score, model_hidden_drop, model_input_drop, until_con
     sv.coord.join(threads)
     sess.close()
 
-    total_time = time.time()-training_start_time
+    total_time = time.time() - training_start_time
     if FLAGS.evaluate_only:
-        print("Testing time: %d seconds" % (total_time))
+        print("Testing time: %d seconds" % total_time)
+    elif FLAGS.predict_only:
+        print("Inference time: %d seconds" % total_time)
     else:
         print("Training time: %d minutes, %d iterations (%3.2f minutes/iteration)" % (total_time/60, training_iteration, total_time/(60*training_iteration)))
-        print("Avg training speed: %f examples/second" % (train_speed))
-        print("Best dev F1: %2.2f" % (best_score*100))
+        print("Avg training speed: %f examples/second" % train_speed)
+        print("Best dev F1: %2.2f" % (best_score * 100))


 if __name__ == '__main__':
     tf.app.flags.DEFINE_string('train_dir', '', 'directory containing preprocessed training data')
@@ -537,6 +548,7 @@ def train(max_epochs, best_score, model_hidden_drop, model_input_drop, until_con
     tf.app.flags.DEFINE_string('nonlinearity', 'relu', 'nonlinearity function to use (tanh, sigmoid, relu)')
     tf.app.flags.DEFINE_boolean('until_convergence', False, 'whether to run until convergence')
     tf.app.flags.DEFINE_boolean('evaluate_only', False, 'whether to only run evaluation')
+    tf.app.flags.DEFINE_boolean('predict_only', False, 'whether to only run prediction')
     tf.app.flags.DEFINE_string('layers', '', 'json definition of layers (dilation, filters, width)')
     tf.app.flags.DEFINE_string('print_preds', '', 'print out predictions (for conll eval script) to given file (or do not if empty)')
     tf.app.flags.DEFINE_boolean('viterbi', False, 'whether to use viterbi inference')