makedata.py
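"""Converts preprocessed German-English parallel text files into TFRecord
files and pickled token dictionaries."""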
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
from itertools import izip
import tensorflow as tf
import re
import os
import pickle as pkl
flags = tf.flags
flags.DEFINE_string('dst_dir', 'data', 'Directory to write to.')
flags.DEFINE_string('src_dir', 'prep', 'Directory containing preprocessed data.')
flags.DEFINE_integer(
    'threshold', 3,
    'Remove tokens that appear fewer than `threshold` times.')
FLAGS = flags.FLAGS
PAD = '<pad>'
UNK = '<unk>'
EOS = '</s>'
def cleanup_sentence(s):
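  """Normalizes a raw line: drops tabs, trims the ends, collapses whitespace."""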
  s = re.sub(r'\t', '', s)
  # Remove leading and trailing whitespace.
  s = s.strip()
  # Collapse runs of whitespace into a single space: this is needed so that
  # the split() calls below return only words and never empty strings.
  s = re.sub(r'\s+', ' ', s)
  return s
def build_dictionary(filename, threshold):
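  """Builds a token -> id dictionary from `filename`, keeping only tokens
  that occur at least `threshold` times; PAD, UNK and EOS get ids 0, 1, 2."""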
  token_to_freq = defaultdict(int)
  print('[ Reading from', filename, ']')
  with open(filename, 'r') as f:
    for line in f:
      words = cleanup_sentence(line).split(' ')
      for w in words:
        token_to_freq[w] += 1
  vocab_list = [PAD, UNK, EOS]
  for token, freq in token_to_freq.iteritems():
    if freq >= threshold:
      vocab_list.append(token)
  print('[ Done making the dictionary. ]')
  print('Training corpus statistics')
  print('Unique words:', len(token_to_freq))
  print('Total words:', sum(token_to_freq.values()))
  print('[ There are effectively', len(vocab_list), 'words in the corpus. ]')
  dictionary = {w: i for i, w in enumerate(vocab_list)}
  return dictionary
def words_to_ids(words, dictionary):
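  """Converts words to ids, mapping out-of-vocabulary words to the UNK id;
  returns the ids, the UNK count and the total token count."""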
  ids = []
  unk_count = 0
  for w in words:
    if w not in dictionary:
      ids.append(dictionary[UNK])
      unk_count += 1
    else:
      ids.append(dictionary[w])
  return ids, unk_count, len(ids)
def build_tfrecords(src_dict, src_fname, target_dict, target_fname,
                    output_fname):
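  """Writes one tf.train.Example per sentence pair to `output_fname`.

  Each example stores the raw tokens ('src', 'target') as space-joined strings
  together with their vocabulary ids ('src_ids', 'target_ids').
  """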
  src_unk_count = 0
  src_tokens_count = 0
  target_unk_count = 0
  target_tokens_count = 0
  lines_count = 0
  writer = tf.python_io.TFRecordWriter(output_fname)
  with open(src_fname, 'r') as fsrc, open(target_fname, 'r') as ftarget:
    for src, target in izip(fsrc, ftarget):
      lines_count += 1
      src = cleanup_sentence(src)
      target = cleanup_sentence(target)
      src_tokens = src.split() + [EOS, ]
      src_ids, unk, tok = words_to_ids(src_tokens, src_dict)
      src_unk_count += unk
      src_tokens_count += tok
      target_tokens = [EOS, ] + target.split()
      target_ids, unk, tok = words_to_ids(target_tokens, target_dict)
      target_unk_count += unk
      target_tokens_count += tok
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'src': tf.train.Feature(
                      bytes_list=tf.train.BytesList(value=[' '.join(src_tokens)])),
                  'target': tf.train.Feature(
                      bytes_list=tf.train.BytesList(value=[' '.join(target_tokens)])),
                  'src_ids': tf.train.Feature(
                      int64_list=tf.train.Int64List(value=src_ids)),
                  'target_ids': tf.train.Feature(
                      int64_list=tf.train.Int64List(value=target_ids)),
              }
          )
      )
      serialized = example.SerializeToString()
      writer.write(serialized)
  # Close the writer so all records are flushed to disk.
  writer.close()
  print('-- %s stats:' % output_fname)
  print('nlines: %d, ntokens (src: %d, tgt: %d); UNK (src: %.2f%%, tgt: %.2f%%)'
        % (lines_count, src_tokens_count, target_tokens_count,
           100 * src_unk_count / src_tokens_count,
           100 * target_unk_count / target_tokens_count))
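# A record written above can be read back along these lines (an illustrative
# sketch, not part of this script; it assumes the same TF 1.x API used in the
# rest of this file and the default output path produced by main()):
#
#   serialized = next(tf.python_io.tf_record_iterator('data/train.de-en.tfrecords'))
#   features = tf.parse_single_example(serialized, {
#       'src': tf.FixedLenFeature([], tf.string),
#       'target': tf.FixedLenFeature([], tf.string),
#       'src_ids': tf.VarLenFeature(tf.int64),
#       'target_ids': tf.VarLenFeature(tf.int64),
#   })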
def main(unused_args):
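  """Builds the dictionaries from the training data, writes TFRecords for
  every split and pickles both dictionaries into `dst_dir`."""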
  if not os.path.exists(FLAGS.dst_dir):
    os.makedirs(FLAGS.dst_dir)
  datasets = {
      'train': ('train.de-en.de', 'train.de-en.en'),
      'valid': ('valid.de-en.de', 'valid.de-en.en'),
      'test': ('test.de-en.de', 'test.de-en.en')
  }
  src_dict = build_dictionary(
      os.path.join(FLAGS.src_dir, datasets['train'][0]), FLAGS.threshold)
  target_dict = build_dictionary(
      os.path.join(FLAGS.src_dir, datasets['train'][1]), FLAGS.threshold)
  for split, fnames in datasets.iteritems():
    build_tfrecords(
        src_dict, os.path.join(FLAGS.src_dir, fnames[0]),
        target_dict, os.path.join(FLAGS.src_dir, fnames[1]),
        os.path.join(FLAGS.dst_dir, '%s.de-en.tfrecords' % split))
  with open(os.path.join(FLAGS.dst_dir, 'dict.de-en.de.pkl'), 'wb') as f:
    pkl.dump(src_dict, f)
  with open(os.path.join(FLAGS.dst_dir, 'dict.de-en.en.pkl'), 'wb') as f:
    pkl.dump(target_dict, f)
if __name__ == "__main__":
  tf.app.run()