"""Prints to stdout different curriculum questions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import textwrap
# Dependency imports
from absl import app
from absl import flags
from absl import logging
import generate_settings
from modules import modules
import six
from six.moves import range
FLAGS = flags.FLAGS
flags.DEFINE_string('filter', '', 'restrict to matching module names')
flags.DEFINE_integer('per_train_module', 200000, 'Num of examples per train module')
flags.DEFINE_integer('per_test_module', 1000, 'Num of examples per test module')
flags.DEFINE_bool('show_dropped', False, 'Whether to print dropped questions')
filtered_modules = collections.OrderedDict([])
counts = {}


def _make_entropy_fn(level, num_levels):
  """This returns a function that returns a subrange of entropy.

  E.g., if level=1 (medium) and num_levels=3, then the returned function will
  map the range [x, x + y] to [x + y/3, x + 2y/3].

  Args:
    level: Integer in range [0, num_levels - 1].
    num_levels: Number of difficulty levels.

  Returns:
    Function to restrict entropy range.
  """
  lower = level / num_levels
  upper = (level + 1) / num_levels

  def modify_entropy(range_):
    assert len(range_) == 2
    length = range_[1] - range_[0]
    return (range_[0] + lower * length, range_[0] + upper * length)

  return modify_entropy
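

# Illustrative worked example (not part of the original script): with level=1
# and num_levels=3, lower is 1/3 and upper is 2/3, so an entropy range of
# (4.0, 7.0), whose length is 3.0, is narrowed to its middle third (5.0, 6.0):
#
#   middle_third = _make_entropy_fn(1, 3)
#   middle_third((4.0, 7.0))  # -> (5.0, 6.0)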


def _filter_and_flatten(modules_):
  """Returns flattened dict, filtered according to FLAGS."""
  flat = collections.OrderedDict()

  def add(submodules, prefix=None):
    for key, module_or_function in six.iteritems(submodules):
      full_name = prefix + '__' + key if prefix is not None else key
      if isinstance(module_or_function, dict):
        add(module_or_function, full_name)
      else:
        if FLAGS.filter not in full_name:
          continue
        flat[full_name] = module_or_function

  add(modules_)

  # Make sure the list of modules is in a deterministic order. This is
  # important when generating across multiple machines.
  flat = collections.OrderedDict(
      [(key, flat[key]) for key in sorted(six.iterkeys(flat))])
  return flat
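

# Illustrative example (not part of the original script): a nested dict such
# as {'arithmetic': {'add': fn, 'mul': gn}} (hypothetical names) is flattened
# to {'arithmetic__add': fn, 'arithmetic__mul': gn}. Passing --filter=mul
# would keep only 'arithmetic__mul', since the filter is a plain substring
# match against the full module name.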


def init_modules(train_split=False):
  """Inits the dicts containing functions for generating modules."""
  if filtered_modules:
    return  # already initialized

  all_modules = collections.OrderedDict([])
  if train_split:
    all_modules['train-easy'] = modules.train(_make_entropy_fn(0, 3))
    all_modules['train-medium'] = modules.train(_make_entropy_fn(1, 3))
    all_modules['train-hard'] = modules.train(_make_entropy_fn(2, 3))
  else:
    all_modules['train'] = modules.train(_make_entropy_fn(0, 1))

  all_modules['interpolate'] = modules.test()
  all_modules['extrapolate'] = modules.test_extra()

  counts['train'] = FLAGS.per_train_module
  counts['train-easy'] = FLAGS.per_train_module // 3
  counts['train-medium'] = FLAGS.per_train_module // 3
  counts['train-hard'] = FLAGS.per_train_module // 3
  counts['interpolate'] = FLAGS.per_test_module
  counts['extrapolate'] = FLAGS.per_test_module

  for regime_, modules_ in six.iteritems(all_modules):
    filtered_modules[regime_] = _filter_and_flatten(modules_)
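

# Illustrative note (not part of the original script): with the default flags,
# init_modules() budgets 200000 questions for the single 'train' regime and
# 1000 each for 'interpolate' and 'extrapolate'. With train_split=True the
# training budget is instead split across difficulty levels, giving
# 200000 // 3 = 66666 questions each for 'train-easy', 'train-medium' and
# 'train-hard'.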


def sample_from_module(module):
  """Samples a problem, ignoring samples with overly long questions / answers.

  Args:
    module: Callable returning a `Problem`.

  Returns:
    Pair `(problem, num_dropped)`, where `problem` is an instance of `Problem`
    and `num_dropped` is an integer >= 0 indicating the number of samples that
    were dropped.
  """
  num_dropped = 0
  while True:
    problem = module()
    question = str(problem.question)
    if len(question) > generate_settings.MAX_QUESTION_LENGTH:
      num_dropped += 1
      if FLAGS.show_dropped:
        logging.warning('Dropping question: %s', question)
      continue
    answer = str(problem.answer)
    if len(answer) > generate_settings.MAX_ANSWER_LENGTH:
      num_dropped += 1
      if FLAGS.show_dropped:
        logging.warning('Dropping question with answer: %s', answer)
      continue
    return problem, num_dropped
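

# Illustrative usage (not part of the original script), e.g. inside main()
# once absl has parsed the flags; the module name 'arithmetic__add' is
# hypothetical and depends on what modules.train() actually provides:
#
#   init_modules()
#   module = filtered_modules['train']['arithmetic__add']
#   problem, num_dropped = sample_from_module(module)
#   print(problem.question, problem.answer, num_dropped)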


def main(unused_argv):
  """Prints Q&As from modules according to FLAGS.filter."""
  init_modules()

  text_wrapper = textwrap.TextWrapper(
      width=80, initial_indent=' ', subsequent_indent=' ')

  for regime, flat_modules in six.iteritems(filtered_modules):
    per_module = counts[regime]
    for module_name, module in six.iteritems(flat_modules):
      # These magic print constants make the header bold.
      print('\033[1m{}/{}\033[0m'.format(regime, module_name))
      num_dropped = 0
      for _ in range(per_module):
        problem, extra_dropped = sample_from_module(module)
        num_dropped += extra_dropped
        text = text_wrapper.fill(
            '{} \033[92m{}\033[0m'.format(problem.question, problem.answer))
        print(text)
      if num_dropped > 0:
        logging.warning('Dropped %d examples', num_dropped)


if __name__ == '__main__':
  app.run(main)
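

# Illustrative invocation (not part of the original script): print a handful
# of questions from every module whose name contains 'arithmetic' (the exact
# module names depend on modules.py), using the flags defined above:
#
#   python generate.py --filter=arithmetic --per_train_module=10 \
#       --per_test_module=5 --show_dropped=True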