diff --git a/classification/dataloader.py b/classification/dataloader.py
index 6c7f6ff..f1d048c 100644
--- a/classification/dataloader.py
+++ b/classification/dataloader.py
@@ -151,13 +151,14 @@ def pad(sequences, sos=None, eos=None, pad_token='', pad_left=True, reverse
 def create_one_batch(x, y, map2id, oov='', gpu=False,
                      sos=None, eos=None, bidirectional=False):
     oov_id = map2id[oov]
-    x_fwd = pad(x, sos=sos, eos=eos, pad_left=True)
-    length = len(x_fwd[0])
-    batch_size = len(x_fwd)
-    x_fwd = [ map2id.get(w, oov_id) for seq in x_fwd for w in seq ]
+    x_padded = pad(x, sos=sos, eos=eos, pad_left=True)
+    length = len(x_padded[0])
+    batch_size = len(x_padded)
+    x_fwd = [ map2id.get(w, oov_id) for seq in x_padded for w in seq ]
     x_fwd = torch.LongTensor(x_fwd)
     assert x_fwd.size(0) == length*batch_size
     x_fwd, y = x_fwd.view(batch_size, length).t().contiguous(), torch.LongTensor(y)
+
     if gpu:
         x_fwd, y = x_fwd.cuda(), y.cuda()
     if bidirectional:
@@ -170,12 +171,13 @@ def create_one_batch(x, y, map2id, oov='', gpu=False,
         if gpu:
             x_bwd = x_bwd.cuda()
         return (x_fwd, x_bwd), y
-    return (x_fwd), y
+
+    return (x_fwd), y, x_padded


 # shuffle training examples and create mini-batches
 def create_batches(x, y, batch_size, map2id, perm=None, sort=False, gpu=False,
-                   sos=None, eos=None, bidirectional=False):
+                   sos=None, eos=None, bidirectional=False, get_text_batches=False):
     lst = perm or list(range(len(x)))

     # sort sequences based on their length; necessary for SST
@@ -185,28 +187,39 @@ def create_batches(x, y, batch_size, map2id, perm=None, sort=False, gpu=False,
     x = [ x[i] for i in lst ]
     y = [ y[i] for i in lst ]
+    txt_batches = None
+    if get_text_batches:
+        txt_batches = []
+
     sum_len = 0.0
     batches_x = [ ]
     batches_y = [ ]
+
+
     size = batch_size
     nbatch = (len(x)-1) // size + 1
     for i in range(nbatch):
-        bx, by = create_one_batch(x[i*size:(i+1)*size], y[i*size:(i+1)*size],
+        bx, by, padded_x = create_one_batch(x[i*size:(i+1)*size], y[i*size:(i+1)*size],
                                   map2id, gpu=gpu, sos=sos, eos=eos, bidirectional=bidirectional)
         sum_len += len(bx[0])
         batches_x.append(bx)
         batches_y.append(by)
+        if get_text_batches:
+            txt_batches.append(padded_x)
+
     if sort:
         perm = list(range(nbatch))
         random.shuffle(perm)
         batches_x = [ batches_x[i] for i in perm ]
         batches_y = [ batches_y[i] for i in perm ]
+        if get_text_batches:
+            txt_batches = [txt_batches[i] for i in perm]

 #    sys.stdout.write("{} batches, avg len: {:.1f}\n".format(
 #        nbatch, sum_len/nbatch
 #    ))
-    return batches_x, batches_y
+    return batches_x, batches_y, txt_batches


 def load_embedding_npz(path):
diff --git a/classification/experiment_params.py b/classification/experiment_params.py
new file mode 100644
index 0000000..fddfc29
--- /dev/null
+++ b/classification/experiment_params.py
@@ -0,0 +1,153 @@
+# these categories have more than 100 training instances.
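+# get_categories() returns the list of Amazon product-category directories to sweep;
+# the commented-out return statements below are alternative category subsets.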
+def get_categories(): + #return ["apparel/", "automotive/", "baby/", "beauty/", "books/", "camera_&_photo/", "cell_phones_&_service/", "computer_&_video_games/", "dvd/", "electronics/", "gourmet_food/", "grocery/", "health_&_personal_care/", "jewelry_&_watches/", "kitchen_&_housewares/", "magazines/", "music/", "outdoor_living/", "software/", "sports_&_outdoors/", "toys_&_games/", "video/"] + #return ["apparel/", "baby/", "beauty/", "books/", "camera_&_photo/", "cell_phones_&_service/", "computer_&_video_games/", "dvd/", "electronics/", "health_&_personal_care/", "kitchen_&_housewares/", "magazines/", "music/", "software/", "sports_&_outdoors/", "toys_&_games/", "video/"] + #return ["camera_&_photo/","apparel/","health_&_personal_care/", "toys_&_games/", "kitchen_&_housewares/", "dvd/","books/", "original_mix/"] + + #return ["kitchen_&_housewares/","dvd/", "books/", "original_mix/"] + #return ["dvd/","original_mix/"] + #return ["kitchen_&_housewares/", "books/"] + #return ["kitchen_&_housewares/"] + return ["books/"] + + + + +class ExperimentParams: + def __init__(self, + path = None, + embedding = None, + loaded_embedding = None, + seed = 314159, + model = "rrnn", + semiring = "plus_times", + use_layer_norm = False, + use_output_gate = False, + use_rho = True, + rho_sum_to_one = False, + use_last_cs = False, + use_epsilon_steps = False, + pattern = "2-gram", + activation = "none", + trainer = "adam", + fix_embedding = True, + batch_size = 64, + max_epoch=500, + d_out="256", + dropout=0.2, + embed_dropout=0.2, + rnn_dropout=0.2, + depth=1, + lr=0.001, + lr_decay=0, + lr_schedule_decay=0.5, + gpu=True, + eval_ite=50, + patience=30, + lr_patience=10, + weight_decay=1e-6, + clip_grad=5, + reg_strength=0, + reg_strength_multiple_of_loss=0, + reg_goal_params=False, + prox_step=False, + num_epochs_debug=-1, + debug_run = False, + sparsity_type="none", + filename_prefix="", + filename_suffix="", + dataset="amazon/", + learned_structure=False, + logging_dir="/home/jessedd/projects/rational-recurrences/classification/logging/", + base_data_dir="/home/jessedd/data/", + output_dir=None, + input_model=None + ): + self.path = path + self.embedding = embedding + self.loaded_embedding = loaded_embedding + self.seed = seed + self.model = model + self.semiring = semiring + self.use_layer_norm = use_layer_norm + self.use_output_gate = use_output_gate + self.use_rho = use_rho + self.rho_sum_to_one = rho_sum_to_one + self.use_last_cs = use_last_cs + self.use_epsilon_steps = use_epsilon_steps + self.pattern = pattern + self.activation = activation + self.trainer = trainer + self.fix_embedding = fix_embedding + self.batch_size = batch_size + self.max_epoch = max_epoch + self.d_out = d_out + self.dropout = dropout + self.embed_dropout = embed_dropout + self.rnn_dropout = rnn_dropout + self.depth = depth + self.lr = lr + self.lr_decay = lr_decay + self.lr_schedule_decay = lr_schedule_decay + self.gpu = gpu + self.eval_ite = eval_ite + self.patience = patience + self.lr_patience = lr_patience + self.weight_decay = weight_decay + self.clip_grad = clip_grad + self.reg_strength = reg_strength + self.reg_strength_multiple_of_loss = reg_strength_multiple_of_loss + self.reg_goal_params = reg_goal_params + self.prox_step = prox_step + self.num_epochs_debug = num_epochs_debug + self.debug_run = debug_run + self.sparsity_type = sparsity_type + self.filename_prefix = filename_prefix + self.filename_suffix = filename_suffix + self.dataset = dataset + self.learned_structure = learned_structure + self.logging_dir = 
logging_dir + self.base_data_dir = base_data_dir + self.output_dir = output_dir + self.input_model = input_model + + self.current_experiment() + + # adjusts the default values with the current experiment + def current_experiment(self): + base_data_dir = self.base_data_dir + if self.debug_run: + base_data_dir += "amazon_debug/" + else: + base_data_dir += self.dataset + self.path = base_data_dir + self.embedding = base_data_dir + "embedding" + + def filename(self): + if self.sparsity_type == "none" and self.learned_structure: + sparsity_name = self.learned_structure + else: + sparsity_name = self.sparsity_type + if self.debug_run: + self.filename_prefix += "DEBUG_" + name = "{}{}_layers={}_lr={:.3E}_dout={}_drout={:.4f}_rnndout={:.4f}_embdout={:.4f}_wdecay={:.2E}_clip={:.2f}_pattern={}_sparsity={}".format( + self.filename_prefix, self.trainer, self.depth, self.lr, self.d_out, self.dropout, self.rnn_dropout, self.embed_dropout, + self.weight_decay, self.clip_grad, self.pattern, sparsity_name) + if self.reg_strength > 0: + name += "_regstr={:.3E}".format(self.reg_strength) + if self.reg_strength_multiple_of_loss: + name += "_regstrmultofloss={}".format(self.reg_strength_multiple_of_loss) + if self.reg_goal_params: + name += "_goalparams={}".format(self.reg_goal_params) + if self.prox_step: + name += "_prox" + if self.filename_suffix != "": + name += self.filename_suffix + if not self.gpu: + name = name + "_cpu" + + return name + + def __str__(self): + return str(vars(self)) + diff --git a/classification/experiment_tools.py b/classification/experiment_tools.py new file mode 100644 index 0000000..887c79d --- /dev/null +++ b/classification/experiment_tools.py @@ -0,0 +1,38 @@ +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +import time +import os + + + +def preload_embed(dir_location): + start = time.time() + import dataloader + embs = dataloader.load_embedding(os.path.join(dir_location,"embedding_filtered")) + print("took {} seconds".format(time.time()-start)) + print("preloaded embeddings from amazon dataset.") + print("") + return embs + + +def general_arg_parser(): + """ CLI args related to training and testing models. """ + p = ArgumentParser(add_help=False) + p.add_argument("-d", '--base_dir', help="Data directory", type=str, required=True) + p.add_argument("-a", "--dataset", help="Dataset name", type=str, required=True) + p.add_argument("-p", "--pattern", help="Pattern specification", type=str, default="1-gram,2-gram,3-gram,4-gram") + p.add_argument("--d_out", help="Output dimension(?)", type=str, default="0,4,0,2") + p.add_argument("-g", "--gpu", help="Use GPU", action='store_true') + p.add_argument('--depth', help="Depth of network", type=int, default=1) + p.add_argument("-s", "--seed", help="Random seed", type=int, default=1234) + p.add_argument("-b", "--batch_size", help="Batch size", type=int, default=64) + p.add_argument("--use_last_cs", help="Only use last hidden state as output value", action='store_true') + + # p.add_argument("--max_doc_len", + # help="Maximum doc length. 
For longer documents, spans of length max_doc_len will be randomly " + # "selected each iteration (-1 means no restriction)", + # type=int, default=-1) + # p.add_argument("-n", "--num_train_instances", help="Number of training instances", type=int, default=None) + # p.add_argument("-e", "--embedding_file", help="Word embedding file", required=True) + + return p + diff --git a/classification/load_learned_structure.py b/classification/load_learned_structure.py new file mode 100644 index 0000000..38df720 --- /dev/null +++ b/classification/load_learned_structure.py @@ -0,0 +1,182 @@ +import sys +from experiment_params import ExperimentParams +import numpy as np +import glob + +np.set_printoptions(edgeitems=3,infstr='inf', + linewidth=9999, nanstr='nan', precision=5, + suppress=True, threshold=1000, formatter=None) + +def main(): + l1_or_entropy = "l1" + + if l1_or_entropy == "l1": + l1_example() + elif l1_or_entropy == "entropy": + entropy_example() + + +# the next few functions are for l1-regularized models, below that are the entropy regularization models + +def l1_example(): + file_base = "/home/jessedd/projects/rational-recurrences/classification/logging/amazon_categories/" + "books/" + file_base += "all_cs_and_equal_rho/hparam_opt/structure_search/add_reg_term_to_loss/" + filename_endings = ["*sparsity=states*goalparams=80*"] + for filename_ending in filename_endings: + filenames = glob.glob(file_base + filename_ending) + + for filename in filenames: + from_file(filename=filename) + +def l1_group_norms(args = None, filename = None, prox = False): + norms, best_valid = get_norms(args, filename) + + if not prox: + threshold = 0.1 + else: + threshold = 0.0001 + #threshold = min(norms[-1][:,0]) + learned_ngrams = norms > threshold + ngram_counts = [0] * (len(learned_ngrams[0]) + 1) + weirdos = [] + + for i in range(len(learned_ngrams)): + cur_ngram = 0 + cur_weird = False + for j in range(len(learned_ngrams[i])): + if cur_ngram == j and learned_ngrams[i][j]: + cur_ngram += 1 + elif cur_ngram == j and not learned_ngrams[i][j]: + continue + elif cur_ngram != j and learned_ngrams[i][j]: + cur_weird = True + elif cur_ngram != j and not learned_ngrams[i][j]: + continue + if cur_weird: + weirdos.append(learned_ngrams[i]) + + ngram_counts[cur_ngram] += 1 + total_params = ngram_counts[1] + 2*ngram_counts[2] + 3*ngram_counts[3] + 4*ngram_counts[4] + print("0,1,2,3,4 grams: {}, total params: {}, num out of order: {}, {}".format(str(ngram_counts), total_params, len(weirdos), best_valid)) + + return "{},{},{},{}".format(ngram_counts[1], ngram_counts[2], ngram_counts[3], ngram_counts[4]), total_params + + +def get_norms(args, filename): + if args: + path = "/home/jessedd/projects/rational-recurrences/classification/logging/" + args.dataset + path += args.filename() + ".txt" + else: + path = filename + + best_valid = None + lines = [] + with open(path, "r") as f: + lines = f.readlines() + + if "sparsity=wfsa" in path: + vals = [] + for line in lines: + try: + vals.append(float(line)) + except: + continue + elif "sparsity=edges" in path or "sparsity=states" in path: + + if "sparsity=edges" in path: + len_groups = 8 + else: + len_groups = 4 + + vals = [] + prev_line_was_data = False + wfsas = [] + for line in lines: + if "best_valid" in line: + best_valid = line.strip() + + split_line = [x for x in line.split(" ") if x != ''] + + if len(split_line) != len_groups and prev_line_was_data: + prev_line_was_data = False + vals.append(wfsas) + + wfsas = [] + else: + edges = [] + for item in split_line: + try: + 
edges.append(float(item)) + except: + continue + if len(edges) == len_groups: + prev_line_was_data = True + wfsas.append(edges) + + vals = vals[-1] + vals = np.asarray(vals) + + assert vals.shape[0] == 24 # this is the number of WFSAs in the model + assert vals.shape[1] == len_groups # this is the number of edges in each WFSA + + return vals, best_valid + + + + + +# these functions are for loading the rhos from entropy regularized models + + +def entropy_example(): + file_base = "/home/jessedd/projects/rational-recurrences/classification/logging/amazon/" + file_name = "norms_adam_layers=1_lr=0.0010000_regstr=0.0100000_dout=256_dropout=0.2_pattern=4-gram_sparsity=rho_entropy.txt" + from_file(file_base + file_name) + + +def entropy_rhos(file_loc, rho_bound): + backwards_lines = [] + with open(file_loc, "r") as f: + lines = f.readlines() + + found_var = False + for i in range(len(lines)): + if i == 0: + continue + if "Variable containing:" in lines[-i]: + break + if found_var: + backwards_lines.append(lines[-i].strip()) + if "[torch.cuda.FloatTensor of size" in lines[-i]: + found_var = True + + backwards_lines = backwards_lines[:len(backwards_lines)-1] + + + return extract_ngrams(backwards_lines, rho_bound) + + +def extract_ngrams(rhos, rho_bound): + ngram_counts = collections.Counter() + num_less_than_pointnine = 0 + for rho_line in rhos: + cur_rho_line = np.fromstring(rho_line, dtype=float, sep = " ") + if max(cur_rho_line) < rho_bound: + num_less_than_pointnine += 1 + cur_ngram = np.argmax(cur_rho_line) + ngram_counts[cur_ngram] = ngram_counts[cur_ngram] + 1 + + pattern = "" + d_out = "" + for i in range(4): + if ngram_counts[i] > 0: + pattern = pattern + "{}-gram,".format(i+1) + d_out = d_out + "{},".format(ngram_counts[i]) + pattern = pattern[:len(pattern)-1] + d_out = d_out[:len(d_out)-1] + + return pattern, d_out, num_less_than_pointnine * 1.0 / sum(ngram_counts.values()) + + +if __name__ == "__main__": + main() diff --git a/classification/modules.py b/classification/modules.py index d9bc3f2..8e34eac 100644 --- a/classification/modules.py +++ b/classification/modules.py @@ -32,7 +32,7 @@ def forward(self, x): class EmbeddingLayer(nn.Module): - def __init__(self, n_d, words, embs=None, fix_emb=True, sos='', eos='', + def __init__(self, words, embs=None, fix_emb=True, sos='', eos='', oov='', pad='', normalize=True): super(EmbeddingLayer, self).__init__() word2id = {} @@ -41,14 +41,9 @@ def __init__(self, n_d, words, embs=None, fix_emb=True, sos='', eos='', for word in embwords: assert word not in word2id, "Duplicate words in pre-trained embeddings" word2id[word] = len(word2id) - - sys.stdout.write("{} pre-trained word embeddings loaded.\n".format(len(word2id))) - if n_d != len(embvecs[0]): - sys.stdout.write("[WARNING] n_d ({}) != word vector size ({}). 
Use {} for embeddings.\n".format( - n_d, len(embvecs[0]), len(embvecs[0]) - )) - n_d = len(embvecs[0]) - + n_d = len(embvecs[0]) + sys.stdout.write("{} pre-trained word embeddings with dim={} loaded.\n".format(len(word2id), + n_d)) for w in deep_iter(words): if w not in word2id: word2id[w] = len(word2id) diff --git a/classification/regularization_search_experiments.py b/classification/regularization_search_experiments.py new file mode 100644 index 0000000..a20c5ef --- /dev/null +++ b/classification/regularization_search_experiments.py @@ -0,0 +1,186 @@ +import load_learned_structure +from run_current_experiment import get_k_sorted_hparams +from experiment_params import ExperimentParams +import train_classifier +import numpy as np +import time + + +def search_reg_str_entropy(cur_assignments, kwargs): + starting_reg_str = kwargs["reg_strength"] + file_base = "/home/jessedd/projects/rational-recurrences/classification/logging/" + kwargs["dataset"] + found_small_enough_reg_str = False + # first search by checking that after 5 epochs, more than half aren't above .9 + kwargs["max_epoch"] = 1 + counter = 0 + rho_bound = .99 + while not found_small_enough_reg_str: + counter += 1 + args = ExperimentParams(**kwargs, **cur_assignments) + cur_valid_err, cur_test_err = train_classifier.main(args) + + learned_pattern, learned_d_out, frac_under_pointnine = load_learned_structure.entropy_rhos( + file_base + args.filename() + ".txt", rho_bound) + print("fraction under {}: {}".format(rho_bound,frac_under_pointnine)) + print("") + if frac_under_pointnine < .25: + kwargs["reg_strength"] = kwargs["reg_strength"] / 2.0 + if kwargs["reg_strength"] < 10**-7: + kwargs["reg_strength"] = starting_reg_str + return counter, "too_big_lr" + else: + found_small_enough_reg_str = True + + found_large_enough_reg_str = False + kwargs["max_epoch"] = 5 + rho_bound = .9 + while not found_large_enough_reg_str: + counter += 1 + args = ExperimentParams(**kwargs, **cur_assignments) + cur_valid_err, cur_test_err = train_classifier.main(args) + + learned_pattern, learned_d_out, frac_under_pointnine = load_learned_structure.entropy_rhos( + file_base + args.filename() + ".txt", rho_bound) + print("fraction under {}: {}".format(rho_bound,frac_under_pointnine)) + print("") + if frac_under_pointnine > .25: + kwargs["reg_strength"] = kwargs["reg_strength"] * 2.0 + if kwargs["reg_strength"] > 10**4: + kwargs["reg_strength"] = starting_reg_str + return counter, "too_small_lr" + else: + found_large_enough_reg_str = True + # to set this back to the default + kwargs["max_epoch"] = 500 + return counter, "okay_lr" + +# ways this can fail: +# too small learning rate +# too large learning rate +# too large step size for reg strength, so it's too big then too small +def search_reg_str_l1(cur_assignments, kwargs): + # the final number of params is within this amount of target + smallest_reg_str = 10**-9 + largest_reg_str = 10**2 + distance_from_target = 10 + starting_reg_str = kwargs["reg_strength"] + found_good_reg_str = False + too_small = False + too_large = False + counter = 0 + reg_str_growth_rate = 2.0 + + while not found_good_reg_str: + counter += 1 + args = ExperimentParams(**kwargs, **cur_assignments) + cur_valid_err, cur_test_err = train_classifier.main(args) + learned_d_out, num_params = load_learned_structure.l1_group_norms(args=args, prox=kwargs["prox_step"]) + + if num_params < kwargs["reg_goal_params"] - distance_from_target: + if too_large: + # reduce size of steps for reg strength + reg_str_growth_rate = (reg_str_growth_rate + 1)/2.0 
+                too_large = False
+            too_small = True
+            kwargs["reg_strength"] = kwargs["reg_strength"] / reg_str_growth_rate
+            if kwargs["reg_strength"] < smallest_reg_str:
+                kwargs["reg_strength"] = starting_reg_str
+                return counter, "too_small_lr", cur_valid_err, learned_d_out
+        elif num_params > kwargs["reg_goal_params"] + distance_from_target:
+            if too_small:
+                # reduce size of steps for reg strength
+                reg_str_growth_rate = (reg_str_growth_rate + 1)/2.0
+                too_small = False
+            too_large = True
+            kwargs["reg_strength"] = kwargs["reg_strength"] * reg_str_growth_rate
+
+            if kwargs["reg_strength"] > largest_reg_str:
+                kwargs["reg_strength"] = starting_reg_str
+
+                # it diverged, and for some reason the weights didn't drop
+                if num_params == int(args.d_out) * 4 and cur_assignments["lr"] > .25 and cur_valid_err > .3:
+                    return counter, "too_big_lr", cur_valid_err, learned_d_out
+                else:
+                    return counter, "too_small_lr", cur_valid_err, learned_d_out
+        else:
+            found_good_reg_str = True
+            return counter, "okay_lr", cur_valid_err, learned_d_out
+
+def train_k_then_l_models(k,l,counter,total_evals,start_time,**kwargs):
+    assert "reg_strength" in kwargs
+    if "prox_step" not in kwargs:
+        kwargs["prox_step"] = False
+    elif kwargs["prox_step"]:
+        assert False, "It's too unstable. books/all_cs_and_equal_rho/hparam_opt/structure_search/proximal_gradient too big then too small"
+    file_base = "/home/jessedd/projects/rational-recurrences/classification/logging/" + kwargs["dataset"]
+    best = {
+        "assignment" : None,
+        "valid_err" : 1,
+        "learned_pattern" : None,
+        "learned_d_out" : None,
+        "reg_strength": None
+    }
+
+    reg_search_counters = []
+    lr_lower_bound = 7*10**-3
+    lr_upper_bound = 1.5
+    all_assignments = get_k_sorted_hparams(k, lr_lower_bound, lr_upper_bound)
+    for i in range(len(all_assignments)):
+
+        valid_assignment = False
+        while not valid_assignment:
+            cur_assignments = all_assignments[i]
+            if kwargs["sparsity_type"] == "rho_entropy":
+                one_search_counter, lr_judgement = search_reg_str_entropy(cur_assignments, kwargs)
+            elif kwargs["sparsity_type"] == "states":
+                one_search_counter, lr_judgement, cur_valid_err, learned_d_out = search_reg_str_l1(
+                    cur_assignments, kwargs)
+                learned_pattern = "1-gram,2-gram,3-gram,4-gram"
+
+            reg_search_counters.append(one_search_counter)
+            if lr_judgement == "okay_lr":
+                valid_assignment = True
+            else:
+                if lr_judgement == "too_big_lr":
+                    # lower the upper bound
+                    lr_upper_bound = cur_assignments['lr']
+                    reverse = True
+                elif lr_judgement == "too_small_lr":
+                    # raise the lower bound
+                    lr_lower_bound = cur_assignments['lr']
+                    reverse = False
+                else:
+                    assert False, "shouldn't be here."
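+                # re-sample the remaining hyperparameter assignments within the updated lr bounds;
+                # reversing keeps the next attempt adjacent to the bound that was just tightened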
+                new_assignments = get_k_sorted_hparams(k-i, lr_lower_bound, lr_upper_bound)
+                if reverse:
+                    new_assignments.reverse()
+                all_assignments[i:len(all_assignments)] = new_assignments
+
+        if kwargs["sparsity_type"] == "rho_entropy":
+            args = ExperimentParams(**kwargs, **cur_assignments)
+            cur_valid_err, cur_test_err = train_classifier.main(args)
+
+            learned_pattern, learned_d_out, frac_under_pointnine = load_learned_structure.entropy_rhos(
+                file_base + args.filename() + ".txt", .9)
+
+        if cur_valid_err < best["valid_err"]:
+            best = {
+                "assignment" : cur_assignments,
+                "valid_err" : cur_valid_err,
+                "learned_pattern" : learned_pattern,
+                "learned_d_out" : learned_d_out,
+                "reg_strength": kwargs["reg_strength"]
+            }
+
+        counter[0] = counter[0] + 1
+        print("trained {} out of {} hyperparameter assignments, so far {} seconds".format(
+            counter[0],total_evals, round(time.time()-start_time, 3)))
+
+    kwargs["reg_strength"] = best["reg_strength"]
+    for i in range(l):
+        args = ExperimentParams(filename_suffix="_{}".format(i),**kwargs, **best["assignment"])
+        cur_valid_err, cur_test_err = train_classifier.main(args)
+        counter[0] = counter[0] + 1
+
+
+    return best, reg_search_counters
diff --git a/classification/run_current_experiment.py b/classification/run_current_experiment.py
new file mode 100644
index 0000000..b5b8a9f
--- /dev/null
+++ b/classification/run_current_experiment.py
@@ -0,0 +1,214 @@
+from experiment_params import ExperimentParams, get_categories
+import train_classifier
+import numpy as np
+import time
+import regularization_search_experiments
+
+
+def main():
+    loaded_embedding = preload_embed()
+
+    exp_num = 11
+
+    start_time = time.time()
+    counter = [0]
+    categories = get_categories()
+
+
+    # a basic experiment
+    if exp_num == 0:
+        args = ExperimentParams(use_rho=True, pattern="4-gram", sparsity_type = "rho_entropy", rho_sum_to_one=True,
+                                reg_strength=0.01, d_out="23", lr=0.001, seed = 34159)
+        train_classifier.main(args)
+
+
+    # finding the largest learning rate that doesn't diverge, for evaluating the claims in this paper:
+    # The Marginal Value of Adaptive Gradient Methods in Machine Learning
+    # https://arxiv.org/abs/1705.08292
+    # conclusion: their results don't hold for our models.
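+    # (sweeps SGD over 10 learning rates spaced from 2.0 down to 0.1, 3 epochs each)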
+ elif exp_num == 1: + lrs = np.linspace(2,0.1, 10) + for lr in lrs: + args = ExperimentParams(pattern="4-gram", d_out="256", trainer="sgd", max_epoch=3, lr=lr, filename_prefix="lr_tuning/") + train_classifier.main(args) + + # baseline experiments for 1-gram up to 4-gram models + elif exp_num == 3: + patterns = ["4-gram", "3-gram", "2-gram", "1-gram"] + m = 20 + n = 5 + total_evals = len(categories) * (len(patterns) + 1) * (m+n) + + for category in categories: + for pattern in patterns: + train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern=pattern, d_out = "24", depth = 1, filename_prefix="all_cs_and_equal_rho/hparam_opt/", + dataset = "amazon_categories/" + category, use_rho=False, + seed=None, loaded_embedding=loaded_embedding) + + train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern="1-gram,2-gram,3-gram,4-gram", d_out = "6,6,6,6", depth = 1, + filename_prefix="all_cs_and_equal_rho/hparam_opt/", + dataset = "amazon_categories/" + category, use_rho = False, seed=None, + loaded_embedding = loaded_embedding) + + # to learn with an L_1 regularizer + # first train with the regularizer, choose the best structure, then do hyperparameter search for that structure + elif exp_num == 6: + d_out = "24" + k = 20 + l = 5 + m = 20 + n = 5 + reg_goal_params_list = [80, 60, 40, 20] + total_evals = len(categories) * (m + n + k + l) * len(reg_goal_params_list) + + all_reg_search_counters = [] + + for category in categories: + for reg_goal_params in reg_goal_params_list: + best, reg_search_counters = regularization_search_experiments.train_k_then_l_models( + k,l, counter, total_evals, start_time, reg_goal_params = reg_goal_params, + pattern = "4-gram", d_out = d_out, sparsity_type = "states", + use_rho = False, + filename_prefix="all_cs_and_equal_rho/hparam_opt/structure_search/add_reg_term_to_loss/", + seed=None, + loaded_embedding=loaded_embedding, reg_strength = 10**-6, + dataset = "amazon_categories/" + category) + + all_reg_search_counters.append(reg_search_counters) + + args = train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern = best['learned_pattern'], d_out = best["learned_d_out"], + learned_structure = "l1-states-learned", reg_goal_params = reg_goal_params, + filename_prefix="all_cs_and_equal_rho/hparam_opt/structure_search/add_reg_term_to_loss/", + seed = None, loaded_embedding = loaded_embedding, + dataset = "amazon_categories/" + category, use_rho = False) + print("search counters:") + for search_counter in all_reg_search_counters: + print(search_counter) + + + # some rho_entropy experiments + elif exp_num == 8: + k = 20 + l = 5 + total_evals = len(categories) * (k + l) + + for d_out in ["24"]:#, "256"]: + for category in categories: + # to learn the structure, and train with the regularizer + best, reg_search_counters = regularization_search_experiments.train_k_then_l_models( + k, l, counter, total_evals, start_time, + use_rho = True, pattern = "4-gram", sparsity_type = "rho_entropy", + rho_sum_to_one=True, reg_strength = 1, d_out=d_out, + filename_prefix="only_last_cs/hparam_opt/reg_str_search/", + dataset = "amazon_categories/" + category, seed=None, + loaded_embedding=loaded_embedding) + + # baseline for rho_entropy experiments + elif exp_num == 9: + categories = ["dvd/"] + patterns = ["1-gram", "2-gram"] #["4-gram", "3-gram", "2-gram", "1-gram"] + m = 20 + n = 5 + total_evals = len(categories) * (len(patterns) + 1) * (m+n) + + for category in categories: + for pattern in patterns: + # train and eval the learned structure + args = 
train_m_then_n_models(m,n,counter, total_evals,start_time, + pattern = pattern, d_out="24", + filename_prefix="only_last_cs/hparam_opt/", + dataset = "amazon_categories/" + category, use_last_cs=True, + use_rho = False, seed=None, loaded_embedding=loaded_embedding) + + # baseline experiments for l1 regularization, on sst. very similar to exp_num 3 + elif exp_num == 10: + patterns = ["4-gram", "3-gram", "2-gram", "1-gram"] + m = 20 + n = 5 + total_evals = m * n + for pattern in patterns: + train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern=pattern, d_out = "24", depth = 1, filename_prefix="all_cs_and_equal_rho/hparam_opt/", + dataset = "sst/", use_rho=False, + seed=None, loaded_embedding=loaded_embedding) + + train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern="1-gram,2-gram,3-gram,4-gram", d_out = "6,6,6,6", depth = 1, + filename_prefix="all_cs_and_equal_rho/hparam_opt/", + dataset = "sst/", use_rho = False, seed=None, + loaded_embedding = loaded_embedding) + + elif exp_num == 11: + + args = ExperimentParams(pattern = "1-gram,2-gram,3-gram,4-gram", d_out = "0,4,0,2", + learned_structure = "l1-states-learned", reg_goal_params = 20, + filename_prefix="all_cs_and_equal_rho/saving_model_for_interpretability/", + seed = None, loaded_embedding = loaded_embedding, + dataset = "amazon_categories/original_mix/", use_rho = False, + clip_grad = 1.09, dropout = 0.1943, rnn_dropout = 0.0805, embed_dropout = 0.3489, + lr = 2.553E-02, weight_decay = 1.64E-06, depth = 1) + cur_valid_err, cur_test_err = train_classifier.main(args) + + + +def preload_embed(): + start = time.time() + import dataloader + embs = dataloader.load_embedding("/home/jessedd/data/amazon/embedding") + print("took {} seconds".format(time.time()-start)) + print("preloaded embeddings from amazon dataset.") + print("") + return embs + +# hparams to search over (from paper): +# clip_grad, dropout, learning rate, rnn_dropout, embed_dropout, l2 regularization (actually weight decay) +def hparam_sample(lr_bounds = [1.5, 10**-3]): + assignments = { + "clip_grad" : np.random.uniform(1.0, 5.0), + "dropout" : np.random.uniform(0.0, 0.5), + "rnn_dropout" : np.random.uniform(0.0, 0.5), + "embed_dropout" : np.random.uniform(0.0, 0.5), + "lr" : np.exp(np.random.uniform(np.log(lr_bounds[0]), np.log(lr_bounds[1]))), + "weight_decay" : np.exp(np.random.uniform(np.log(10**-5), np.log(10**-7))), + } + + return assignments + +#orders them in increasing order of lr +def get_k_sorted_hparams(k,lr_upper_bound=1.5, lr_lower_bound=10**-3): + all_assignments = [] + + for i in range(k): + cur = hparam_sample(lr_bounds=[lr_lower_bound,lr_upper_bound]) + all_assignments.append([cur['lr'], cur]) + all_assignments.sort() + return [assignment[1] for assignment in all_assignments] + +def train_m_then_n_models(m,n,counter, total_evals,start_time,**kwargs): + best_assignment = None + best_valid_err = 1 + all_assignments = get_k_sorted_hparams(m) + for i in range(m): + cur_assignments = all_assignments[i] + args = ExperimentParams(**kwargs, **cur_assignments) + cur_valid_err, cur_test_err = train_classifier.main(args) + if cur_valid_err < best_valid_err: + best_assignment = cur_assignments + best_valid_err = cur_valid_err + counter[0] = counter[0] + 1 + print("trained {} out of {} hyperparameter assignments, so far {} seconds".format( + counter[0],total_evals, round(time.time()-start_time, 3))) + + for i in range(n): + args = ExperimentParams(filename_suffix="_{}".format(i),**kwargs,**best_assignment) + cur_valid_err, cur_test_err = 
train_classifier.main(args) + counter[0] = counter[0] + 1 + print("trained {} out of {} hyperparameter assignments, so far {} seconds".format( + counter[0],total_evals, round(time.time()-start_time, 3))) + return best_assignment + +if __name__ == "__main__": + main() diff --git a/classification/run_local_experiment.py b/classification/run_local_experiment.py new file mode 100755 index 0000000..673fe0d --- /dev/null +++ b/classification/run_local_experiment.py @@ -0,0 +1,233 @@ +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + +import sys +import os +from experiment_params import ExperimentParams, get_categories +import train_classifier +import numpy as np +import time +import regularization_search_experiments +import experiment_tools + + +def main(argv): + loaded_embedding = experiment_tools.preload_embed(os.path.join(argv.base_dir,argv.dataset)) + + exp_num = 0 + + start_time = time.time() + counter = [0] + categories = get_categories() + + + # a basic experiment + if exp_num == 0: + args = ExperimentParams(pattern = argv.pattern, d_out = argv.d_out, + learned_structure = argv.learned_structure, reg_goal_params = argv.reg_goal_params, + filename_prefix=argv.filename_prefix, + seed = argv.seed, loaded_embedding = loaded_embedding, + dataset = argv.dataset, use_rho = False, + clip_grad = argv.clip, dropout = argv.dropout, rnn_dropout = argv.rnn_dropout, + embed_dropout = argv.embed_dropout, gpu=argv.gpu, + max_epoch = argv.max_epoch, patience = argv.patience, + batch_size = argv.batch_size, use_last_cs=argv.use_last_cs, + lr = argv.lr, weight_decay = argv.weight_decay, depth = argv.depth, logging_dir = argv.logging_dir, + base_data_dir = argv.base_dir, output_dir = argv.model_save_dir) + cur_valid_err = train_classifier.main(args) + + # finding the largest learning rate that doesn't diverge, for evaluating the claims in this paper: + # The Marginal Value of Adaptive Gradient Methods in Machine Learning + # https://arxiv.org/abs/1705.08292 + # conclusion: their results don't hold for our models. 
+ elif exp_num == 1: + lrs = np.linspace(2,0.1, 10) + for lr in lrs: + args = ExperimentParams(pattern="4-gram", d_out="256", trainer="sgd", max_epoch=3, lr=lr, filename_prefix="lr_tuning/") + train_classifier.main(args) + + # baseline experiments for 1-gram up to 4-gram models + elif exp_num == 3: + patterns = ["4-gram", "3-gram", "2-gram", "1-gram"] + m = 20 + n = 5 + total_evals = len(categories) * (len(patterns) + 1) * (m+n) + + for category in categories: + for pattern in patterns: + train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern=pattern, d_out = "24", depth = 1, filename_prefix="all_cs_and_equal_rho/hparam_opt/", + dataset = "amazon_categories/" + category, use_rho=False, + seed=None, loaded_embedding=loaded_embedding) + + train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern="1-gram,2-gram,3-gram,4-gram", d_out = "6,6,6,6", depth = 1, + filename_prefix="all_cs_and_equal_rho/hparam_opt/", + dataset = "amazon_categories/" + category, use_rho = False, seed=None, + loaded_embedding = loaded_embedding) + + # to learn with an L_1 regularizer + # first train with the regularizer, choose the best structure, then do hyperparameter search for that structure + elif exp_num == 6: + d_out = "24" + k = 20 + l = 5 + m = 20 + n = 5 + reg_goal_params_list = [80, 60, 40, 20] + total_evals = len(categories) * (m + n + k + l) * len(reg_goal_params_list) + + all_reg_search_counters = [] + + for category in categories: + for reg_goal_params in reg_goal_params_list: + best, reg_search_counters = regularization_search_experiments.train_k_then_l_models( + k,l, counter, total_evals, start_time, reg_goal_params = reg_goal_params, + pattern = "4-gram", d_out = d_out, sparsity_type = "states", + use_rho = False, + filename_prefix="all_cs_and_equal_rho/hparam_opt/structure_search/add_reg_term_to_loss/", + seed=None, + loaded_embedding=loaded_embedding, reg_strength = 10**-6, + dataset = "amazon_categories/" + category) + + all_reg_search_counters.append(reg_search_counters) + + args = train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern = best['learned_pattern'], d_out = best["learned_d_out"], + learned_structure = "l1-states-learned", reg_goal_params = reg_goal_params, + filename_prefix="all_cs_and_equal_rho/hparam_opt/structure_search/add_reg_term_to_loss/", + seed = None, loaded_embedding = loaded_embedding, + dataset = "amazon_categories/" + category, use_rho = False) + print("search counters:") + for search_counter in all_reg_search_counters: + print(search_counter) + + + # some rho_entropy experiments + elif exp_num == 8: + k = 20 + l = 5 + total_evals = len(categories) * (k + l) + + for d_out in ["24"]:#, "256"]: + for category in categories: + # to learn the structure, and train with the regularizer + best, reg_search_counters = regularization_search_experiments.train_k_then_l_models( + k, l, counter, total_evals, start_time, + use_rho = True, pattern = "4-gram", sparsity_type = "rho_entropy", + rho_sum_to_one=True, reg_strength = 1, d_out=d_out, + filename_prefix="only_last_cs/hparam_opt/reg_str_search/", + dataset = "amazon_categories/" + category, seed=None, + loaded_embedding=loaded_embedding) + + # baseline for rho_entropy experiments + elif exp_num == 9: + categories = ["dvd/"] + patterns = ["1-gram", "2-gram"] #["4-gram", "3-gram", "2-gram", "1-gram"] + m = 20 + n = 5 + total_evals = len(categories) * (len(patterns) + 1) * (m+n) + + for category in categories: + for pattern in patterns: + # train and eval the learned structure + args = 
train_m_then_n_models(m,n,counter, total_evals,start_time, + pattern = pattern, d_out="24", + filename_prefix="only_last_cs/hparam_opt/", + dataset = "amazon_categories/" + category, use_last_cs=True, + use_rho = False, seed=None, loaded_embedding=loaded_embedding) + + # baseline experiments for l1 regularization, on sst. very similar to exp_num 3 + elif exp_num == 10: + patterns = ["4-gram", "3-gram", "2-gram", "1-gram"] + m = 20 + n = 5 + total_evals = m * n + for pattern in patterns: + train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern=pattern, d_out = "24", depth = 1, filename_prefix="all_cs_and_equal_rho/hparam_opt/", + dataset = "sst/", use_rho=False, + seed=None, loaded_embedding=loaded_embedding) + + train_m_then_n_models(m,n,counter, total_evals, start_time, + pattern="1-gram,2-gram,3-gram,4-gram", d_out = "6,6,6,6", depth = 1, + filename_prefix="all_cs_and_equal_rho/hparam_opt/", + dataset = "sst/", use_rho = False, seed=None, + loaded_embedding = loaded_embedding) + + + +# hparams to search over (from paper): +# clip_grad, dropout, learning rate, rnn_dropout, embed_dropout, l2 regularization (actually weight decay) +def hparam_sample(lr_bounds = [1.5, 10**-3]): + assignments = { + "clip_grad" : np.random.uniform(1.0, 5.0), + "dropout" : np.random.uniform(0.0, 0.5), + "rnn_dropout" : np.random.uniform(0.0, 0.5), + "embed_dropout" : np.random.uniform(0.0, 0.5), + "lr" : np.exp(np.random.uniform(np.log(lr_bounds[0]), np.log(lr_bounds[1]))), + "weight_decay" : np.exp(np.random.uniform(np.log(10**-5), np.log(10**-7))), + } + + return assignments + +#orders them in increasing order of lr +def get_k_sorted_hparams(k,lr_upper_bound=1.5, lr_lower_bound=10**-3): + all_assignments = [] + + for i in range(k): + cur = hparam_sample(lr_bounds=[lr_lower_bound,lr_upper_bound]) + all_assignments.append([cur['lr'], cur]) + all_assignments.sort() + return [assignment[1] for assignment in all_assignments] + +def train_m_then_n_models(m,n,counter, total_evals,start_time,**kwargs): + best_assignment = None + best_valid_err = 1 + all_assignments = get_k_sorted_hparams(m) + for i in range(m): + cur_assignments = all_assignments[i] + args = ExperimentParams(**kwargs, **cur_assignments) + cur_valid_err = train_classifier.main(args) + if cur_valid_err < best_valid_err: + best_assignment = cur_assignments + best_valid_err = cur_valid_err + counter[0] = counter[0] + 1 + print("trained {} out of {} hyperparameter assignments, so far {} seconds".format( + counter[0],total_evals, round(time.time()-start_time, 3))) + + for i in range(n): + args = ExperimentParams(filename_suffix="_{}".format(i),**kwargs,**best_assignment) + cur_valid_err = train_classifier.main(args) + counter[0] = counter[0] + 1 + print("trained {} out of {} hyperparameter assignments, so far {} seconds".format( + counter[0],total_evals, round(time.time()-start_time, 3))) + return best_assignment + + + +def training_arg_parser(): + """ CLI args related to training models. 
""" + p = ArgumentParser(add_help=False) + p.add_argument("--learned_structure", help="Learned structure", type=str, default="l1-states-learned") + p.add_argument('--reg_goal_params', type=int, default = 20) + p.add_argument('--filename_prefix', help='logging file prefix?', type=str, default="all_cs_and_equal_rho/saving_model_for_interpretability/") + p.add_argument("-t", "--dropout", help="Use dropout", type=float, default=0.1943) + p.add_argument("--rnn_dropout", help="Use RNN dropout", type=float, default=0.0805) + p.add_argument("--embed_dropout", help="Use RNN dropout", type=float, default=0.3489) + p.add_argument("-l", "--lr", help="Learning rate", type=float, default=2.553E-02) + p.add_argument("--clip", help="Gradient clipping", type=float, default=1.09) + p.add_argument('-w', "--weight_decay", help="Weight decay", type=float, default=1.64E-06) + p.add_argument("-m", "--model_save_dir", help="where to save the trained model", type=str) + p.add_argument("--logging_dir", help="Logging directory", type=str) + p.add_argument("--max_epoch", help="Number of iterations", type=int, default=500) + p.add_argument("--patience", help="Patience parameter (for early stopping)", type=int, default=30) + # p.add_argument("-r", "--scheduler", help="Use reduce learning rate on plateau schedule", action='store_true') + # p.add_argument("--debug", help="Debug", type=int, default=0) + return p + + +if __name__ == "__main__": + parser = ArgumentParser(description=__doc__, + formatter_class=ArgumentDefaultsHelpFormatter, + parents=[experiment_tools.general_arg_parser(), training_arg_parser()]) + sys.exit(main(parser.parse_args())) diff --git a/classification/test_experiment.py b/classification/test_experiment.py new file mode 100755 index 0000000..8a039eb --- /dev/null +++ b/classification/test_experiment.py @@ -0,0 +1,38 @@ +import os +import sys +import experiment_tools +import train_classifier +from experiment_params import get_categories, ExperimentParams +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + + +def main(argv): + loaded_embedding = experiment_tools.preload_embed(os.path.join(argv.base_dir,argv.dataset)) + + # a basic experiment + args = ExperimentParams(pattern = argv.pattern, d_out = argv.d_out, + seed = argv.seed, loaded_embedding = loaded_embedding, + dataset = argv.dataset, use_rho = False, + depth = argv.depth, gpu=argv.gpu, + batch_size=argv.batch_size, + base_data_dir = argv.base_dir, input_model=argv.input_model) + + if argv.visualize > 0: + train_classifier.main_visualize(args, os.path.join(argv.base_dir,argv.dataset), argv.visualize, + argv.norms_file) + else: + _,_,_ = train_classifier.main_test(args) + + return 0 + + + +if __name__ == '__main__': + parser = ArgumentParser(description=__doc__, + formatter_class=ArgumentDefaultsHelpFormatter, + parents=[experiment_tools.general_arg_parser()]) + parser.add_argument("-m", "--input_model", help="Saved model file", required=True, type=str) + parser.add_argument("-v", "--visualize", help="Visualize (rather than test): top_k phrases to visualize", type=int, default=0) + parser.add_argument("--norms_file", help="In visualization mode: file with norms, from which to select patterns and pattern states.", type=str) + + sys.exit(main(parser.parse_args())) diff --git a/classification/train_classifier.py b/classification/train_classifier.py index 972caaa..0e378db 100644 --- a/classification/train_classifier.py +++ b/classification/train_classifier.py @@ -1,4 +1,5 @@ import sys +import os import argparse import numpy as np 
@@ -6,6 +7,10 @@ import torch.optim as optim from torch.autograd import Variable from torch.optim.lr_scheduler import ReduceLROnPlateau +from termcolor import colored + + +from tensorboardX import SummaryWriter sys.path.append("..") import classification.dataloader as dataloader @@ -13,6 +18,8 @@ from semiring import * import rrnn + + SOS, EOS = "", "" class Model(nn.Module): def __init__(self, args, emb_layer, nclasses=2): @@ -29,15 +36,16 @@ def __init__(self, args, emb_layer, nclasses=2): use_selu = 1 else: assert args.activation == "none" + if args.model == "lstm": self.encoder = nn.LSTM( emb_layer.n_d, - args.d, + args.d_out, args.depth, dropout=args.dropout, bidirectional=False ) - d_out = args.d + d_out = args.d_out elif args.model == "rrnn": if args.semiring == "plus_times": self.semiring = PlusTimesSemiring @@ -51,8 +59,9 @@ def __init__(self, args, emb_layer, nclasses=2): self.encoder = rrnn.RRNN( self.semiring, emb_layer.n_d, - args.d, + args.d_out, args.depth, + pattern=args.pattern, dropout=args.dropout, rnn_dropout=args.rnn_dropout, bidirectional=False, @@ -60,12 +69,17 @@ def __init__(self, args, emb_layer, nclasses=2): use_relu=use_relu, use_selu=use_selu, layer_norm=args.use_layer_norm, - use_output_gate=args.use_output_gate + use_output_gate=args.use_output_gate, + use_rho=args.use_rho, + rho_sum_to_one=args.rho_sum_to_one, + use_last_cs=args.use_last_cs, + use_epsilon_steps=args.use_epsilon_steps ) - d_out = args.d + d_out = args.d_out else: assert False - self.out = nn.Linear(d_out, nclasses) + out_size = sum([int(one_size) for one_size in d_out.split(",")]) + self.out = nn.Linear(out_size, nclasses) def init_hidden(self, batch_size): @@ -80,13 +94,15 @@ def forward(self, input): input_fwd = input emb_fwd = self.emb_layer(input_fwd) emb_fwd = self.drop(emb_fwd) - out_fwd, hidden_fwd = self.encoder(emb_fwd) + + out_fwd, hidden_fwd, _ = self.encoder(emb_fwd) batch, length = emb_fwd.size(-2), emb_fwd.size(0) out_fwd = out_fwd.view(length, batch, 1, -1) feat = out_fwd[-1,:,0,:] else: emb = self.emb_layer(input) emb = self.drop(emb) + output, hidden = self.encoder(emb) batch, length = emb.size(-2), emb.size(0) output = output.view(length, batch, 1, -1) @@ -95,6 +111,15 @@ def forward(self, input): feat = self.drop(feat) return self.out(feat) + # Assume rrnn model + def visualize(self, input, min_max=rrnn.Max): + assert self.args.model == "rrnn" + input_fwd = input + emb_fwd = self.emb_layer(input_fwd) +# emb_fwd = self.drop(emb_fwd) + _, _, traces = self.encoder(emb_fwd, None, True, True, min_max) + + return traces def eval_model(niter, model, valid_x, valid_y): model.eval() @@ -118,10 +143,163 @@ def eval_model(niter, model, valid_x, valid_y): return 1.0 - correct / cnt +def get_states_weights(model, args): + + embed_dim = model.emb_layer.n_d + num_edges_in_wfsa = model.encoder.rnn_lst[0].cells[0].k + num_wfsas = int(args.d_out) + + reshaped_weights = model.encoder.rnn_lst[0].cells[0].weight.view(embed_dim, num_wfsas, num_edges_in_wfsa) + if len(model.encoder.rnn_lst) > 1: + reshaped_second_layer_weights = model.encoder.rnn_lst[1].cells[0].weight.view(num_wfsas, num_wfsas, num_edges_in_wfsa) + reshaped_weights = torch.cat((reshaped_weights, reshaped_second_layer_weights), 0) + elif len(model.encoder.rnn_lst) > 2: + assert False, "This regularization is only implemented for 2-layer networks." + + # to stack the transition and self-loops, so e.g. 
states[...,0] contains the transition and self-loop weights + + states = torch.cat((reshaped_weights[...,0:int(num_edges_in_wfsa/2)], + reshaped_weights[...,int(num_edges_in_wfsa/2):num_edges_in_wfsa]),0) + return states + +# this computes the group lasso penalty term +def get_regularization_groups(model, args): + if args.sparsity_type == "wfsa": + embed_dim = model.emb_layer.n_d + num_edges_in_wfsa = model.encoder.rnn_lst[0].k + reshaped_weights = model.encoder.rnn_lst[0].weight.view(embed_dim, args.d_out, num_edges_in_wfsa) + l2_norm = reshaped_weights.norm(2, dim=0).norm(2, dim=1) + return l2_norm + elif args.sparsity_type == 'edges': + return model.encoder.rnn_lst[0].weight.norm(2, dim=0) + elif args.sparsity_type == 'states': + states = get_states_weights(model, args) + return states.norm(2,dim=0) # a num_wfsa by n-gram matrix + elif args.sparsity_type == "rho_entropy": + assert args.depth == 1, "rho_entropy regularization currently implemented for single layer networks" + bidirectional = model.encoder.rnn_lst[0].cells[0].bidirectional + assert not bidirectional, "bidirectional not implemented" + + num_edges_in_wfsa = model.encoder.rnn_lst[0].cells[0].k + num_wfsas = int(args.d_out) + bias_final = model.encoder.rnn_lst[0].cells[0].bias_final + + sm = nn.Softmax(dim=2) + # the 1 in the line below is for non-bidirectional models, would be 2 for bidirectional + rho = sm(bias_final.view(1, num_wfsas, int(num_edges_in_wfsa/2))) + entropy_to_sum = rho * rho.log() * -1 + entropy = entropy_to_sum.sum(dim=2) + return entropy + + + + +def log_groups(model, args, logging_file, groups=None): + if groups is not None: + + if args.sparsity_type == "rho_entropy": + num_edges_in_wfsa = model.encoder.rnn_lst[0].cells[0].k + num_wfsas = int(args.d_out) + bias_final = model.encoder.rnn_lst[0].cells[0].bias_final + + sm = nn.Softmax(dim=2) + # the 1 in the line below is for non-bidirectional models, would be 2 for bidirectional + rho = sm(bias_final.view(1, num_wfsas, int(num_edges_in_wfsa/2))) + logging_file.write(str(rho)) + + else: + logging_file.write(str(groups)) + else: + if args.sparsity_type == "wfsa": + embed_dim = model.emb_layer.n_d + num_edges_in_wfsa = model.encoder.rnn_lst[0].k + reshaped_weights = model.encoder.rnn_lst[0].weight.view(embed_dim, args.d_out, num_edges_in_wfsa) + l2_norm = reshaped_weights.norm(2, dim=0).norm(2, dim=1) + logging_file.write(str(l2_norm)) + + elif args.sparsity_type == 'edges': + embed_dim = model.emb_layer.n_d + num_edges_in_wfsa = model.encoder.rnn_lst[0].k + reshaped_weights = model.encoder.rnn_lst[0].weight.view(embed_dim, args.d_out, num_edges_in_wfsa) + logging_file.write(str(reshaped_weights.norm(2, dim=0))) + #model.encoder.rnn_lst[0].weight.norm(2, dim=0) + elif args.sparsity_type == 'states': + assert False, "can implement this based on get_regularization_groups, but that keeps changing" + logging_file.write(str(states.norm(2,dim=0))) # a num_wfsa by n-gram matrix + + +def init_logging(args): + + dir_path = args.logging_dir + args.dataset + "/" + filename = args.filename() + ".txt" + + if not os.path.exists(dir_path): + os.mkdir(dir_path) + + if not os.path.exists(dir_path + args.filename_prefix): + os.mkdir(dir_path + args.filename_prefix) + + torch.set_printoptions(threshold=5000) + + logging_file = open(dir_path + filename, "w") + + tmp = args.loaded_embedding + args.loaded_embedding=True + logging_file.write(str(args)) + args.loaded_embedding = tmp + + #print(args) + print("saving in {}".format(args.dataset + args.filename())) + return logging_file 
+ + +def regularization_stop(args, model): + if args.sparsity_type == "states" and args.prox_step: + states = get_states_weights(model, args) + if states.norm(2,dim=0).sum().data[0] == 0: + return True + else: + return False + +# following https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning#Group_lasso +# w_g - args.reg_strength * (w_g / ||w_g||_2) +def prox_step(model, args): + if args.sparsity_type == "states": + + states = get_states_weights(model, args) + num_states = states.shape[2] + + embed_dim = model.emb_layer.n_d + num_edges_in_wfsa = model.encoder.rnn_lst[0].cells[0].k + num_wfsas = int(args.d_out) + + reshaped_weights = model.encoder.rnn_lst[0].cells[0].weight.view(embed_dim, num_wfsas, num_edges_in_wfsa) + if len(model.encoder.rnn_lst) > 1: + assert False, "This regularization is only implemented for 2-layer networks." + + first_weights = reshaped_weights[...,0:int(num_edges_in_wfsa/2)] + second_weights = reshaped_weights[...,int(num_edges_in_wfsa/2):num_edges_in_wfsa] + + + for i in range(num_wfsas): + for j in range(num_states): + cur_group = states[:,i,j].data + cur_first_weights = first_weights[:,i,j].data + cur_second_weights = second_weights[:,i,j].data + if cur_group.norm(2) < args.reg_strength: + #cur_group.add_(-cur_group) + cur_first_weights.add_(-cur_first_weights) + cur_second_weights.add_(-cur_second_weights) + else: + #cur_group.add_(-args.reg_strength*cur_group/cur_group.norm(2)) + cur_first_weights.add_(-args.reg_strength*cur_first_weights/cur_group.norm(2)) + cur_second_weights.add_(-args.reg_strength*cur_second_weights/cur_group.norm(2)) + else: + assert False, "haven't implemented anything else" + def train_model(epoch, model, optimizer, train_x, train_y, valid_x, valid_y, - test_x, test_y, - best_valid, test_err, unchanged, scheduler): + best_valid, unchanged, scheduler, logging_file): model.train() args = model.args N = len(train_x) @@ -129,7 +307,10 @@ def train_model(epoch, model, optimizer, criterion = nn.CrossEntropyLoss() cnt = 0 stop = False + + import time for x, y in zip(train_x, train_y): + iter_start_time = time.time() niter += 1 cnt += 1 model.zero_grad() @@ -137,49 +318,309 @@ def train_model(epoch, model, optimizer, if args.gpu: x, y = x.cuda(), y.cuda() x = (x) + output = model(x) loss = criterion(output, y) - loss.backward() - torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad) + + if args.sparsity_type == "none": + reg_loss = loss + regularization_term = 0 + else: + regularization_groups = get_regularization_groups(model, args) + + regularization_term = regularization_groups.sum() + + if args.reg_strength_multiple_of_loss and args.reg_strength == 0: + args.reg_strength = loss.data[0]*args.reg_strength_multiple_of_loss/regularization_term.data[0] + + if args.prox_step: + reg_loss = loss + else: + reg_loss = loss + args.reg_strength * regularization_term + + reg_loss.backward() + torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad) optimizer.step() + if args.prox_step: + prox_step(model, args) + + if args.num_epochs_debug != -1 and epoch > args.num_epochs_debug: + import pdb; pdb.set_trace() + + # to log every batch's loss, and how long it took + #logging_file.write("took {} seconds. 
reg_term: {}, reg_loss: {}\n".format(round(time.time() - iter_start_time,2), + # round(float(regularization_term),4), round(float(reg_loss),4))) + regularization_groups = get_regularization_groups(model, args) + log_groups(model, args, logging_file, regularization_groups) + valid_err = eval_model(niter, model, valid_x, valid_y) scheduler.step(valid_err) - sys.stdout.write("-" * 89 + "\n") - sys.stdout.write("| Epoch={} | iter={} | lr={:.6f} | train_loss={:.6f} | valid_err={:.6f} |\n".format( + epoch_string = "\n" + epoch_string += "-" * 110 + "\n" + epoch_string += "| Epoch={} | iter={} | lr={:.5f} | reg_strength={} | train_loss={:.6f} | valid_err={:.6f} | regularized_loss={:.6f} |\n".format( epoch, niter, optimizer.param_groups[0]["lr"], + args.reg_strength, loss.data[0], - valid_err - )) - sys.stdout.write("-" * 89 + "\n") - sys.stdout.flush() + valid_err, + reg_loss.data[0] + ) + epoch_string += "-" * 110 + "\n" + logging_file.write(epoch_string) + sys.stdout.write(epoch_string) + sys.stdout.flush() + if valid_err < best_valid: unchanged = 0 best_valid = valid_err - test_err = eval_model(niter, model, test_x, test_y) else: unchanged += 1 - if unchanged >= args.patience: + if unchanged >= args.patience or regularization_stop(args, model): stop = True + sys.stdout.write("\n") sys.stdout.flush() - return best_valid, unchanged, test_err, stop + return best_valid, unchanged, stop + + +def read_norms_file(norms_file): + norms = np.loadtxt(norms_file) + + vals = -1 * np.ones(norms.shape[0]) + + for i in range(norms.shape[0]): + if norms[i, 0] > 0.1: + vals[i] = norms.shape[1] - 1 + for j in range(1, norms.shape[1]): + if norms[i, j]*10 < np.max(norms[i, j-1], 1): + vals[i] = j - 1 + break + + return vals + +def main_visualize(args, dataset_file, top_k, norms_file=None): + # datasets and labels are 3-size array: 0 - train, 1 - dev, 2 - test + model, datasets, labels, emb_layer = main_init(args) + + model.eval() + + # Creating dev batches + d, l, txt_batches = dataloader.create_batches( + datasets[1], labels[1], + args.batch_size, + emb_layer.word2id, + sort=True, + gpu=args.gpu, + sos=SOS, + eos=EOS, + get_text_batches=True + ) + + # Loading trained model + if args.gpu: + state_dict = torch.load(args.input_model) + else: + state_dict = torch.load(args.input_model, map_location=lambda storage, loc: storage) + + model.load_state_dict(state_dict) + + if args.gpu: + model.cuda() + + + norms = None if norms_file is None else read_norms_file(norms_file) + + #top_samples = torch.zeros(len(traces[0][0])), ) + + n_patts = [int(one_size) for one_size in args.d_out.split(",")] + + patt_lengths = [int(patt_len[0]) for patt_len in args.pattern.split(",")] + + # Filtering-out pattern lengths with 0 patterns + patt_lengths = [patt_lengths[i] for i in range(len(patt_lengths)) if n_patts[i] > 0 ] + n_patts = [n_patts[i] for i in range(len(n_patts)) if n_patts[i] > 0 ] + + # Trace for each pattern in each pattern length + all_scores = [[] for i in n_patts] + all_traces = [[] for i in n_patts] + all_scores_min = [[] for i in n_patts] + all_traces_min = [[] for i in n_patts] + + all_x = [] + + for x, txt_x in zip(d, txt_batches): + all_x.extend(txt_x) + # print(len(x[0]), len(txt_x)) + assert(len(x[0]) == len(txt_x)) + + x = Variable(x) + if args.gpu: + x = x.cuda() + x = (x) + + # traces shape: n-patterns lengths X length of pattern (for intermediate pattern score) + # each item in this matrix is a TraceElementParallel representing the traces of a batch of documents + # and all patterns of the given pattern 
length. + traces = model.visualize(x, rrnn.Max) + traces_min = model.visualize(x, rrnn.Min) + + # Saving all scores and all indices of main path (u_indices) + for i in range(len(n_patts)): + if len(all_traces[i]) == 0: + for j in range(len(traces[i])): + all_scores[i].append(traces[i][j].score) + all_traces[i].append(traces[i][j].u_indices) + all_scores_min[i].append(traces_min[i][j].score) + all_traces_min[i].append(traces_min[i][j].u_indices) + else: + for j in range(len(traces[i])): + all_scores[i][j] = np.concatenate((all_scores[i][j], traces[i][j].score)) + all_traces[i][j] = np.concatenate((all_traces[i][j], traces[i][j].u_indices)) + all_scores_min[i][j] = np.concatenate((all_scores_min[i][j], traces_min[i][j].score)) + all_traces_min[i][j] = np.concatenate((all_traces_min[i][j], traces_min[i][j].u_indices)) + + # loop one: pattern length + for (i, same_length_traces) in enumerate(all_traces): + print("\nPattern length {}".format(patt_lengths[i])) + + # loop two: number of patterns of each length + for k in range(n_patts[i]): + if norms is not None and norms[k] == -1: + continue + + print("\nPattern index {}\n".format(k)) + # Loop three: top phrases of intermediate states for each pattern + for j in range(len(same_length_traces)): + if norms is not None and norms[k] == (j - 2): + break -def main(args): - np.random.seed(args.seed) - torch.manual_seed(args.seed) + print("\nSublength {}\n".format(j)) + + patt_traces = same_length_traces[j] + patt_traces_min = all_traces_min[i][j] + + local_scores = all_scores[i][j][:, k] + local_traces = patt_traces[:, k, :] + local_scores_min = all_scores_min[i][j][:, k] + local_traces_min = patt_traces_min[:, k, :] + + # Sorting scores, traces and documents by the score. + sorted_traces = sorted(zip(local_scores, local_traces, all_x), + key=lambda pair: pair[0], reverse=True) + sorted_traces_min = sorted(zip(local_scores_min, local_traces_min, all_x), + key=lambda pair: pair[0], reverse=False) + + local_top_traces = sorted_traces[:top_k] + local_worst_traces = sorted_traces_min[:top_k] + + print("Best:") + print_top_traces(local_top_traces, j+1) + print("Worst:") + print_top_traces(local_worst_traces, j+1) + + + sys.stdout.flush() + + +# Print the top phrases for a given pattern +# top_traces: triplets containing the scores, traces and documents for the top k matches +# tmp_patt_len: the pattern length to print (if tracing only the first tokens of the pattern) +def print_top_traces(top_traces, tmp_patt_len=None): + # A helper function: print a given phrase. + # Calls recursive function (print_rec) that prints a given state and all it self loops. + def print_traces(index, score, u_indices, doc, tmp_patt_len=None): + if tmp_patt_len is None: + tmp_patt_len = len(u_indices) + + print("{}. 
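To make the rendering below concrete: print_rec prints each main-path token (the positions recorded in u_indices) with an _MP suffix and every token consumed by a self-loop between two main-path positions with an _SL suffix, and print_traces appends the pattern score at the end of the line. A toy example of the expected output (coloring omitted; the document and indices are made up):

doc = ["the", "movie", "was", "not", "good"]
u_indices = [1, 4]      # a 2-state pattern whose main path fires on "movie" and "good"

# print_rec(doc, 0, u_indices, tmp_patt_len=2) would produce:
#   movie_MP was_SL not_SL good_MP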
{}.".format(index, u_indices[:tmp_patt_len]), end=' ') + print_rec(doc, 0, u_indices, tmp_patt_len) + print(float(score)) + + # Print words from one state + def print_rec(doc, u_index, u_indices, tmp_patt_len): + doc_index = u_indices[u_index] + print(colored(doc[doc_index], 'red'), end='_MP ') + + u_index += 1 + + if u_index == tmp_patt_len: + return + + doc_index += 1 + + while (doc_index < u_indices[u_index]): + print(doc[doc_index], end='_SL ') + doc_index += 1 + + print_rec(doc, u_index, u_indices, tmp_patt_len) + + # each triplet is composed of [0] the pattern score + # [1] the pattern traces (i.e., the indices of the main paths in the given document) + # [2] the document itself + for (i, triplet) in enumerate(top_traces): + print_traces(i+1, triplet[0], triplet[1], triplet[2], tmp_patt_len) + + +def main_test(args): + model, datasets, labels, emb_layer = main_init(args) + + batched_datasets = [] + batched_labels = [] + for [dataset, label] in zip(datasets, labels): + d, l = dataloader.create_batches( + dataset, label, + args.batch_size, + emb_layer.word2id, + sort=True, + gpu=args.gpu, + sos=SOS, + eos=EOS + ) + batched_datasets.append(d) + batched_labels.append(l) + + + if args.gpu: + state_dict = torch.load(args.input_model) + else: + state_dict = torch.load(args.input_model, map_location=lambda storage, loc: storage) + + model.load_state_dict(state_dict) + + if args.gpu: + model.cuda() + + + names = ['train', 'valid', 'test'] + + + errs = [eval_model(0, model, d, l) for [d,l] in zip(batched_datasets, batched_labels)] + + for [name, err] in zip(names, errs): + sys.stdout.write("{}_err: {:.6f}\n".format(name, err)) + + sys.stdout.flush() + return errs + +def main_init(args): + if args.seed: + np.random.seed(args.seed) + torch.manual_seed(args.seed) train_X, train_Y, valid_X, valid_Y, test_X, test_Y = dataloader.read_SST(args.path) data = train_X + valid_X + test_X - embs = dataloader.load_embedding(args.embedding) + if args.loaded_embedding: + embs = args.loaded_embedding + else: + embs = dataloader.load_embedding(args.embedding) emb_layer = modules.EmbeddingLayer( - args.d, data, + data, fix_emb=args.fix_embedding, sos=SOS, eos=EOS, @@ -187,19 +628,21 @@ def main(args): ) nclasses = max(train_Y) + 1 - random_perm = list(range(len(train_X))) + + model = Model(args, emb_layer, nclasses) + + return model, [train_X, valid_X, test_X], [train_Y, valid_Y, test_Y], emb_layer + + +def main(args): + logging_file = init_logging(args) + model, datasets, labels, emb_layer = main_init(args) + + random_perm = list(range(len(datasets[0]))) np.random.shuffle(random_perm) + valid_x, valid_y = dataloader.create_batches( - valid_X, valid_Y, - args.batch_size, - emb_layer.word2id, - sort=True, - gpu=args.gpu, - sos=SOS, - eos=EOS - ) - test_x, test_y = dataloader.create_batches( - test_X, test_Y, + datasets[1], labels[1], args.batch_size, emb_layer.word2id, sort=True, @@ -208,7 +651,6 @@ def main(args): eos=EOS ) - model = Model(args, emb_layer, nclasses) if args.gpu: model.cuda() @@ -227,15 +669,16 @@ def main(args): weight_decay=args.weight_decay ) - scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=args.lr_patience, verbose=True) + scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=args.lr_schedule_decay, patience=args.lr_patience, verbose=True) best_valid = 1e+8 - test_err = 1e+8 unchanged = 0 + + for epoch in range(args.max_epoch): np.random.shuffle(random_perm) train_x, train_y = dataloader.create_batches( - train_X, train_Y, + datasets[0], labels[0], 
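Throughout this patch, d_out and the pattern specification are comma-separated strings aligned position by position: pattern gives the n-gram length of each group of WFSAs and d_out gives how many WFSAs of that length to allocate (a 0 entry disables the group). A small sketch of the parsing used in main_visualize and in RRNN.__init__ (the values are made up; note that int(p[0]) only supports single-digit n):

pattern = "1-gram,2-gram,3-gram,4-gram"
d_out = "24,24,0,12"

patt_lengths = [int(p[0]) for p in pattern.split(",")]   # [1, 2, 3, 4]
n_patts = [int(s) for s in d_out.split(",")]             # [24, 24, 0, 12]

# drop pattern lengths with zero WFSAs, keeping the two lists aligned
patt_lengths = [patt_lengths[i] for i in range(len(patt_lengths)) if n_patts[i] > 0]
n_patts = [n for n in n_patts if n > 0]
print(patt_lengths, n_patts)   # [1, 2, 4] [24, 24, 12]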
args.batch_size, emb_layer.word2id, perm=random_perm, @@ -244,13 +687,12 @@ def main(args): sos=SOS, eos=EOS ) - best_valid, unchanged, test_err, stop = train_model( + best_valid, unchanged, stop = train_model( epoch, model, optimizer, train_x, train_y, valid_x, valid_y, - test_x, test_y, - best_valid, test_err, - unchanged, scheduler + best_valid, + unchanged, scheduler, logging_file ) if stop: @@ -259,14 +701,20 @@ def main(args): if args.lr_decay > 0: optimizer.param_groups[0]["lr"] *= args.lr_decay + if args.output_dir is not None: + of = os.path.join(args.output_dir, "best_model.pth") + print("Writing model to", of) + torch.save(model.state_dict(), of) - sys.stdout.write("best_valid: {:.6f}\n".format( - best_valid - )) - sys.stdout.write("test_err: {:.6f}\n".format( - test_err - )) + + sys.stdout.write("best_valid: {:.6f}\n".format(best_valid)) +# sys.stdout.write("test_err: {:.6f}\n".format(test_err)) sys.stdout.flush() + logging_file.write("best_valid: {:.6f}\n".format(best_valid)) +# logging_file.write("test_err: {:.6f}\n".format(test_err)) + logging_file.close() + return best_valid#, test_err + def str2bool(v): @@ -278,7 +726,7 @@ def str2bool(v): raise argparse.ArgumentTypeError("Boolean value expected.") -if __name__ == "__main__": +def parse_args(): argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler="resolve") argparser.add_argument("--seed", type=int, default=31415) argparser.add_argument("--model", type=str, default="rrnn") @@ -293,7 +741,7 @@ def str2bool(v): help="if using pretrained embeddings, fix them or not during training") argparser.add_argument("--batch_size", "--batch", type=int, default=32) argparser.add_argument("--max_epoch", type=int, default=100) - argparser.add_argument("--d", type=int, default=256) + argparser.add_argument("--d_out", type=int, default=256) argparser.add_argument("--dropout", type=float, default=0.2, help="dropout intra RNN layers") argparser.add_argument("--embed_dropout", type=float, default=0.2, @@ -311,6 +759,12 @@ def str2bool(v): argparser.add_argument("--clip_grad", type=float, default=5) args = argparser.parse_args() + return args + + + +if __name__ == "__main__": + args = parse_args() print(args) sys.stdout.flush() diff --git a/cuda/rrnn.py b/cuda/bigram_rrnn.py similarity index 99% rename from cuda/rrnn.py rename to cuda/bigram_rrnn.py index 2f12294..6eb7533 100644 --- a/cuda/rrnn.py +++ b/cuda/bigram_rrnn.py @@ -1,4 +1,4 @@ -RRNN = """ +BIGRAM_RRNN = """ extern "C" { __global__ void rrnn_fwd( diff --git a/cuda/rrnn_semiring.py b/cuda/bigram_rrnn_semiring.py similarity index 99% rename from cuda/rrnn_semiring.py rename to cuda/bigram_rrnn_semiring.py index dae3295..47b98f2 100644 --- a/cuda/rrnn_semiring.py +++ b/cuda/bigram_rrnn_semiring.py @@ -1,4 +1,4 @@ -RRNN_SEMIRING = """ +BIGRAM_RRNN_SEMIRING = """ extern "C" { __global__ void rrnn_semiring_fwd( diff --git a/cuda/fourgram_rrnn.py b/cuda/fourgram_rrnn.py new file mode 100644 index 0000000..3889fa9 --- /dev/null +++ b/cuda/fourgram_rrnn.py @@ -0,0 +1,185 @@ +FOURGRAM_RRNN = """ + +extern "C" { + __global__ void rrnn_fwd( + const float * __restrict__ u, + const float * __restrict__ c1_init, + const float * __restrict__ c2_init, + const float * __restrict__ c3_init, + const float * __restrict__ c4_init, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ c1, + float * __restrict__ c2, + float * __restrict__ c3, + float * __restrict__ c4, + int semiring_type) { + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + 
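The 1/2/3/4-gram kernels in this and the following files all unroll the same per-dimension recurrence: each state decays through its forget gate and receives inflow from the previous state (or directly from the input for the first state), with the input gating already applied on the Python side. A NumPy reference of the forward pass for a single WFSA dimension, which the fixed-k kernels specialize (shapes and values are illustrative):

import numpy as np

def ngram_forward(u, f, c_init):
    # u, f: (length, n_states) gated inputs and forget gates for one WFSA dimension
    # c_init: (n_states,) initial cell values
    length, n_states = u.shape
    c = np.zeros((length, n_states))
    cur = c_init.astype(float).copy()
    for t in range(length):
        prev = cur.copy()                      # cell values from the previous time step
        for j in range(n_states):
            inflow = u[t, j] if j == 0 else prev[j - 1] * u[t, j]
            cur[j] = prev[j] * f[t, j] + inflow
        c[t] = cur
    return c                                   # c[-1] plays the role of the kernels' last_c outputs

rng = np.random.RandomState(0)
c = ngram_forward(rng.rand(7, 4), rng.rand(7, 4), np.zeros(4))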
threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + + const float *up = u + (col*k); + float *c1p = c1 + col; + float *c2p = c2 + col; + float *c3p = c3 + col; + float *c4p = c4 + col; + float cur_c1 = *(c1_init + col); + float cur_c2 = *(c2_init + col); + float cur_c3 = *(c3_init + col); + float cur_c4 = *(c4_init + col); + + for (int row = 0; row < len; ++row) { + float u1 = *(up); + float u2 = *(up+1); + float u3 = *(up+2); + float u4 = *(up+3); + + float forget1 = *(up+4); + float forget2 = *(up+5); + float forget3 = *(up+6); + float forget4 = *(up+7); + + float prev_c1 = cur_c1; + float prev_c2 = cur_c2; + float prev_c3 = cur_c3; + cur_c1 = cur_c1 * forget1 + u1; + cur_c2 = cur_c2 * forget2 + (prev_c1) * u2; + cur_c3 = cur_c3 * forget3 + (prev_c2) * u3; + cur_c4 = cur_c4 * forget4 + (prev_c3) * u4; + + *c1p = cur_c1; + *c2p = cur_c2; + *c3p = cur_c3; + *c4p = cur_c4; + + up += ncols_u; + c1p += ncols; + c2p += ncols; + c3p += ncols; + c4p += ncols; + } + } + + __global__ void rrnn_bwd( + const float * __restrict__ u, + const float * __restrict__ c1_init, + const float * __restrict__ c2_init, + const float * __restrict__ c3_init, + const float * __restrict__ c4_init, + const float * __restrict__ c1, + const float * __restrict__ c2, + const float * __restrict__ c3, + const float * __restrict__ c4, + const float * __restrict__ grad_c1, + const float * __restrict__ grad_c2, + const float * __restrict__ grad_c3, + const float * __restrict__ grad_c4, + const float * __restrict__ grad_last_c1, + const float * __restrict__ grad_last_c2, + const float * __restrict__ grad_last_c3, + const float * __restrict__ grad_last_c4, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ grad_u, + float * __restrict__ grad_c1_init, + float * __restrict__ grad_c2_init, + float * __restrict__ grad_c3_init, + float * __restrict__ grad_c4_init, + int semiring_type) { + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + + int ncols_u = ncols*k; + + float cur_c1 = *(grad_last_c1 + col); + float cur_c2 = *(grad_last_c2 + col); + float cur_c3 = *(grad_last_c3 + col); + float cur_c4 = *(grad_last_c4 + col); + + const float *up = u + (col*k) + (len-1)*ncols_u; + const float *c1p = c1 + col + (len-1)*ncols; + const float *c2p = c2 + col + (len-1)*ncols; + const float *c3p = c3 + col + (len-1)*ncols; + const float *c4p = c4 + col + (len-1)*ncols; + + const float *gc1p = grad_c1 + col + (len-1)*ncols; + const float *gc2p = grad_c2 + col + (len-1)*ncols; + const float *gc3p = grad_c3 + col + (len-1)*ncols; + const float *gc4p = grad_c4 + col + (len-1)*ncols; + float *gup = grad_u + (col*k) + (len-1)*ncols_u; + + for (int row = len-1; row >= 0; --row) { + float u1 = *(up); + float u2 = *(up+1); + float u3 = *(up+2); + float u4 = *(up+3); + float forget1 = *(up+4); + float forget2 = *(up+5); + float forget3 = *(up+6); + float forget4 = *(up+7); + + const float c1_val = *c1p; + const float c2_val = *c2p; + const float c3_val = *c3p; + const float c4_val = *c4p; + + const float prev_c1_val = (row>0) ? (*(c1p-ncols)) : (*(c1_init+col)); + const float prev_c2_val = (row>0) ? (*(c2p-ncols)) : (*(c2_init+col)); + const float prev_c3_val = (row>0) ? (*(c3p-ncols)) : (*(c3_init+col)); + const float prev_c4_val = (row>0) ? 
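The backward kernels mirror that recurrence in reverse: the gradient held by state j flows back through its own forget gate and also down from state j+1 through that state's input (cur_c1 = gc1*forget1 + gc2*u2, cur_c2 = gc2*forget2 + gc3*u3, and so on, with the last state keeping only its forget term). A finite-difference probe is a cheap way to sanity-check such hand-written backwards; a self-contained sketch for the 2-state case:

import numpy as np

def twogram_objective(u, f):
    # mirrors the forward recurrence (c1 = c1*f1 + u1, c2 = c2*f2 + prev_c1*u2)
    # and returns last_c1 + last_c2 as a scalar stand-in for the downstream loss
    c1 = c2 = 0.0
    for t in range(len(u)):
        c1_prev = c1
        c1 = c1 * f[t, 0] + u[t, 0]
        c2 = c2 * f[t, 1] + c1_prev * u[t, 1]
    return c1 + c2

rng = np.random.RandomState(1)
u, f = rng.rand(6, 2), rng.rand(6, 2)

t, j, eps = 3, 0, 1e-5
up, um = u.copy(), u.copy()
up[t, j] += eps
um[t, j] -= eps
num_grad = (twogram_objective(up, f) - twogram_objective(um, f)) / (2 * eps)
# num_grad should agree with the corresponding grad_u entry produced by rrnn_bwd
# when grad_last_c1 = grad_last_c2 = 1 and the per-step grad_c1s, grad_c2s are zero.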
(*(c4p-ncols)) : (*(c4_init+col)); + + const float gc1 = *(gc1p) + cur_c1; + const float gc2 = *(gc2p) + cur_c2; + const float gc3 = *(gc3p) + cur_c3; + const float gc4 = *(gc4p) + cur_c4; + + float gu1 = gc1; + *(gup) = gu1; + float gforget1 = gc1*prev_c1_val; + *(gup+4) = gforget1; + + float gu2 = gc2*(prev_c1_val); + *(gup+1) = gu2; + float gforget2 = gc2*prev_c2_val; + *(gup+5) = gforget2; + + float gu3 = gc3*(prev_c2_val); + *(gup+2) = gu3; + float gforget3 = gc3*prev_c3_val; + *(gup+6) = gforget3; + + float gu4 = gc4*(prev_c3_val); + *(gup+3) = gu4; + float gforget4 = gc4*prev_c4_val; + *(gup+7) = gforget4; + + cur_c1 = gc1 * forget1 + gc2 * u2; + cur_c2 = gc2 * forget2 + gc3 * u3; + cur_c3 = gc3 * forget3 + gc4 * u4; + cur_c4 = gc4 * forget4; + + up -= ncols_u; + c1p -= ncols; + c2p -= ncols; + c3p -= ncols; + c4p -= ncols; + gup -= ncols_u; + gc1p -= ncols; + gc2p -= ncols; + gc3p -= ncols; + gc4p -= ncols; + } + + *(grad_c1_init + col) = cur_c1; + *(grad_c2_init + col) = cur_c2; + *(grad_c3_init + col) = cur_c3; + *(grad_c4_init + col) = cur_c4; + } +} +""" diff --git a/cuda/onegram_rrnn.py b/cuda/onegram_rrnn.py new file mode 100644 index 0000000..c5cd080 --- /dev/null +++ b/cuda/onegram_rrnn.py @@ -0,0 +1,89 @@ +ONEGRAM_RRNN = """ + +extern "C" { + __global__ void rrnn_fwd( + const float * __restrict__ u, + const float * __restrict__ c1_init, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ c1, + int semiring_type) { + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + + const float *up = u + (col*k); + float *c1p = c1 + col; + float cur_c1 = *(c1_init + col); + + for (int row = 0; row < len; ++row) { + float u1 = *(up); + + float forget1 = *(up+1); + + cur_c1 = cur_c1 * forget1 + u1; + + *c1p = cur_c1; + + up += ncols_u; + c1p += ncols; + } + } + + __global__ void rrnn_bwd( + const float * __restrict__ u, + const float * __restrict__ c1_init, + const float * __restrict__ c1, + const float * __restrict__ grad_c1, + const float * __restrict__ grad_last_c1, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ grad_u, + float * __restrict__ grad_c1_init, + int semiring_type) { + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + + int ncols_u = ncols*k; + + float cur_c1 = *(grad_last_c1 + col); + + const float *up = u + (col*k) + (len-1)*ncols_u; + const float *c1p = c1 + col + (len-1)*ncols; + + const float *gc1p = grad_c1 + col + (len-1)*ncols; + float *gup = grad_u + (col*k) + (len-1)*ncols_u; + + for (int row = len-1; row >= 0; --row) { + float u1 = *(up); + float forget1 = *(up+1); + + const float c1_val = *c1p; + + const float prev_c1_val = (row>0) ? 
(*(c1p-ncols)) : (*(c1_init+col)); + + const float gc1 = *(gc1p) + cur_c1; + + float gu1 = gc1; + *(gup) = gu1; + float gforget1 = gc1*prev_c1_val; + *(gup+1) = gforget1; + + cur_c1 = gc1 * forget1; + + up -= ncols_u; + c1p -= ncols; + gup -= ncols_u; + gc1p -= ncols; + } + + *(grad_c1_init + col) = cur_c1; + } +} +""" diff --git a/cuda/threegram_rrnn.py b/cuda/threegram_rrnn.py new file mode 100644 index 0000000..b8f78b8 --- /dev/null +++ b/cuda/threegram_rrnn.py @@ -0,0 +1,153 @@ +THREEGRAM_RRNN = """ + +extern "C" { + __global__ void rrnn_fwd( + const float * __restrict__ u, + const float * __restrict__ c1_init, + const float * __restrict__ c2_init, + const float * __restrict__ c3_init, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ c1, + float * __restrict__ c2, + float * __restrict__ c3, + int semiring_type) { + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + + const float *up = u + (col*k); + float *c1p = c1 + col; + float *c2p = c2 + col; + float *c3p = c3 + col; + float cur_c1 = *(c1_init + col); + float cur_c2 = *(c2_init + col); + float cur_c3 = *(c3_init + col); + + for (int row = 0; row < len; ++row) { + float u1 = *(up); + float u2 = *(up+1); + float u3 = *(up+2); + + float forget1 = *(up+3); + float forget2 = *(up+4); + float forget3 = *(up+5); + + float prev_c1 = cur_c1; + float prev_c2 = cur_c2; + cur_c1 = cur_c1 * forget1 + u1; + cur_c2 = cur_c2 * forget2 + (prev_c1) * u2; + cur_c3 = cur_c3 * forget3 + (prev_c2) * u3; + + *c1p = cur_c1; + *c2p = cur_c2; + *c3p = cur_c3; + + up += ncols_u; + c1p += ncols; + c2p += ncols; + c3p += ncols; + } + } + + __global__ void rrnn_bwd( + const float * __restrict__ u, + const float * __restrict__ c1_init, + const float * __restrict__ c2_init, + const float * __restrict__ c3_init, + const float * __restrict__ c1, + const float * __restrict__ c2, + const float * __restrict__ c3, + const float * __restrict__ grad_c1, + const float * __restrict__ grad_c2, + const float * __restrict__ grad_c3, + const float * __restrict__ grad_last_c1, + const float * __restrict__ grad_last_c2, + const float * __restrict__ grad_last_c3, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ grad_u, + float * __restrict__ grad_c1_init, + float * __restrict__ grad_c2_init, + float * __restrict__ grad_c3_init, + int semiring_type) { + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + + int ncols_u = ncols*k; + + float cur_c1 = *(grad_last_c1 + col); + float cur_c2 = *(grad_last_c2 + col); + float cur_c3 = *(grad_last_c3 + col); + + const float *up = u + (col*k) + (len-1)*ncols_u; + const float *c1p = c1 + col + (len-1)*ncols; + const float *c2p = c2 + col + (len-1)*ncols; + const float *c3p = c3 + col + (len-1)*ncols; + + const float *gc1p = grad_c1 + col + (len-1)*ncols; + const float *gc2p = grad_c2 + col + (len-1)*ncols; + const float *gc3p = grad_c3 + col + (len-1)*ncols; + float *gup = grad_u + (col*k) + (len-1)*ncols_u; + + for (int row = len-1; row >= 0; --row) { + float u1 = *(up); + float u2 = *(up+1); + float u3 = *(up+2); + float forget1 = *(up+3); + float forget2 = *(up+4); + float forget3 = *(up+5); + + const float c1_val = *c1p; + const float c2_val = *c2p; + const float c3_val = *c3p; + + const float prev_c1_val = (row>0) ? (*(c1p-ncols)) : (*(c1_init+col)); + const float prev_c2_val = (row>0) ? 
(*(c2p-ncols)) : (*(c2_init+col)); + const float prev_c3_val = (row>0) ? (*(c3p-ncols)) : (*(c3_init+col)); + + const float gc1 = *(gc1p) + cur_c1; + const float gc2 = *(gc2p) + cur_c2; + const float gc3 = *(gc3p) + cur_c3; + + float gu1 = gc1; + *(gup) = gu1; + float gforget1 = gc1*prev_c1_val; + *(gup+3) = gforget1; + + float gu2 = gc2*(prev_c1_val); + *(gup+1) = gu2; + float gforget2 = gc2*prev_c2_val; + *(gup+4) = gforget2; + + float gu3 = gc3*(prev_c2_val); + *(gup+2) = gu3; + float gforget3 = gc3*prev_c3_val; + *(gup+5) = gforget3; + + cur_c1 = gc1 * forget1 + gc2 * u2; + cur_c2 = gc2 * forget2 + gc3 * u3; + cur_c3 = gc3 * forget3; + + up -= ncols_u; + c1p -= ncols; + c2p -= ncols; + c3p -= ncols; + gup -= ncols_u; + gc1p -= ncols; + gc2p -= ncols; + gc3p -= ncols; + } + + *(grad_c1_init + col) = cur_c1; + *(grad_c2_init + col) = cur_c2; + *(grad_c3_init + col) = cur_c3; + } +} +""" diff --git a/cuda/twogram_rrnn.py b/cuda/twogram_rrnn.py new file mode 100644 index 0000000..aca1950 --- /dev/null +++ b/cuda/twogram_rrnn.py @@ -0,0 +1,121 @@ +TWOGRAM_RRNN = """ + +extern "C" { + __global__ void rrnn_fwd( + const float * __restrict__ u, + const float * __restrict__ c1_init, + const float * __restrict__ c2_init, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ c1, + float * __restrict__ c2, + int semiring_type) { + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + + const float *up = u + (col*k); + float *c1p = c1 + col; + float *c2p = c2 + col; + float cur_c1 = *(c1_init + col); + float cur_c2 = *(c2_init + col); + + for (int row = 0; row < len; ++row) { + float u1 = *(up); + float u2 = *(up+1); + + float forget1 = *(up+2); + float forget2 = *(up+3); + + float prev_c1 = cur_c1; + cur_c1 = cur_c1 * forget1 + u1; + cur_c2 = cur_c2 * forget2 + (prev_c1) * u2; + + *c1p = cur_c1; + *c2p = cur_c2; + + up += ncols_u; + c1p += ncols; + c2p += ncols; + } + } + + __global__ void rrnn_bwd( + const float * __restrict__ u, + const float * __restrict__ c1_init, + const float * __restrict__ c2_init, + const float * __restrict__ c1, + const float * __restrict__ c2, + const float * __restrict__ grad_c1, + const float * __restrict__ grad_c2, + const float * __restrict__ grad_last_c1, + const float * __restrict__ grad_last_c2, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ grad_u, + float * __restrict__ grad_c1_init, + float * __restrict__ grad_c2_init, + int semiring_type) { + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + + int ncols_u = ncols*k; + + float cur_c1 = *(grad_last_c1 + col); + float cur_c2 = *(grad_last_c2 + col); + + const float *up = u + (col*k) + (len-1)*ncols_u; + const float *c1p = c1 + col + (len-1)*ncols; + const float *c2p = c2 + col + (len-1)*ncols; + + const float *gc1p = grad_c1 + col + (len-1)*ncols; + const float *gc2p = grad_c2 + col + (len-1)*ncols; + float *gup = grad_u + (col*k) + (len-1)*ncols_u; + + for (int row = len-1; row >= 0; --row) { + float u1 = *(up); + float u2 = *(up+1); + float forget1 = *(up+2); + float forget2 = *(up+3); + + const float c1_val = *c1p; + const float c2_val = *c2p; + + const float prev_c1_val = (row>0) ? (*(c1p-ncols)) : (*(c1_init+col)); + const float prev_c2_val = (row>0) ? 
(*(c2p-ncols)) : (*(c2_init+col)); + + const float gc1 = *(gc1p) + cur_c1; + const float gc2 = *(gc2p) + cur_c2; + + float gu1 = gc1; + *(gup) = gu1; + float gforget1 = gc1*prev_c1_val; + *(gup+2) = gforget1; + + float gu2 = gc2*(prev_c1_val); + *(gup+1) = gu2; + float gforget2 = gc2*prev_c2_val; + *(gup+3) = gforget2; + + cur_c1 = gc1 * forget1 + gc2 * u2; + cur_c2 = gc2 * forget2; + + up -= ncols_u; + c1p -= ncols; + c2p -= ncols; + gup -= ncols_u; + gc1p -= ncols; + gc2p -= ncols; + } + + *(grad_c1_init + col) = cur_c1; + *(grad_c2_init + col) = cur_c2; + } +} +""" diff --git a/cuda/unigram_rrnn.py b/cuda/unigram_rrnn.py new file mode 100644 index 0000000..511ac0a --- /dev/null +++ b/cuda/unigram_rrnn.py @@ -0,0 +1,87 @@ +UNIGRAM_RRNN = """ + +extern "C" { + __global__ void rrnn_fwd( + const float * __restrict__ u, + const float * __restrict__ c_init, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ c, + int semiring_type) { + assert (k == K); + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + + const float *up = u + (col*k); + float *cp = c + col; + float cur_c = *(c_init + col); + const float eps_val = *(eps + (col%dim)); + + for (int row = 0; row < len; ++row) { + float u = *(up); + float forget = *(up+1); + float prev_c = cur_c; + cur_c = cur_c * forget + u; + *cp = cur_c; + up += ncols_u; + cp += ncols; + } + } + + __global__ void rrnn_bwd( + const float * __restrict__ u, + const float * __restrict__ eps, + const float * __restrict__ c_init, + const float * __restrict__ c, + const float * __restrict__ grad_c, + const float * __restrict__ grad_last_c, + const int len, + const int batch, + const int dim, + const int k, + float * __restrict__ grad_u, + float * __restrict__ grad_c_init, + int semiring_type) { + assert (k == K); + int ncols = batch*dim; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + + int ncols_u = ncols*k; + + float cur_c = *(grad_last_c + col); + const float *up = u + (col*k) + (len-1)*ncols_u; + const float *cp = c + col + (len-1)*ncols; + + const float *gcp = grad_c + col + (len-1)*ncols; + float *gup = grad_u + (col*k) + (len-1)*ncols_u; + + for (int row = len-1; row >= 0; --row) { + float u = *(up); + float forget = *(up+1); + + const float c_val = *cp; + const float prev_c_val = (row>0) ? 
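Because the Python wrapper pre-scales the input by (1 - forget) before calling these kernels, the unigram recurrence c_t = f_t * c_{t-1} + u_t is effectively an exponential moving average of the projected input. A scalar sketch of that view (values are illustrative):

import numpy as np

def unigram_cell(x_proj, forget):
    # x_proj, forget: (length,) projected inputs and sigmoid forget gates
    c, cs = 0.0, []
    for x_t, f_t in zip(x_proj, forget):
        u_t = x_t * (1.0 - f_t)     # coupled input gate, as in real_unigram_forward
        c = c * f_t + u_t           # the rrnn_fwd recurrence for a single state
        cs.append(c)
    return np.array(cs)

unigram_cell(np.ones(3), np.full(3, 0.5))   # -> [0.5, 0.75, 0.875], approaching 1.0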
(*(cp-ncols)) : (*(c_init+col)); + const float gc = *(gcp) + cur_c; + + float gu = gc; + *(gup) = gu; + float gforget = gc*prev_c_val; + *(gup+1) = gforget; + + cur_c = gc * forget; + + up -= ncols_u; + cp -= ncols; + gup -= ncols_u; + gcp -= ncols; + } + + *(grad_c_init + col) = cur_c; + } +} +""" diff --git a/language_model/train_lm.py b/language_model/train_lm.py index f29e348..c00486b 100644 --- a/language_model/train_lm.py +++ b/language_model/train_lm.py @@ -108,6 +108,7 @@ def __init__(self, words, args): self.n_d, self.n_d, self.depth, + pattern=args.pattern, dropout=args.dropout, rnn_dropout=args.rnn_dropout, use_tanh=use_tanh, @@ -208,7 +209,10 @@ def repackage_hidden(args, hidden): if args.model == "lstm": return (Variable(hidden[0].data), Variable(hidden[1].data)) elif args.model == "rrnn": - return (Variable(hidden[0].data), Variable(hidden[1].data)) + if args.pattern == "bigram": + return (Variable(hidden[0].data), Variable(hidden[1].data)) + elif args.pattern == "unigram": + return Variable(hidden.data) else: assert False @@ -389,6 +393,7 @@ def str2bool(v): argparser.add_argument("--seed", type=int, default=31415) argparser.add_argument("--model", type=str, default="rrnn") argparser.add_argument("--semiring", type=str, default="plus_times") + argparser.add_argument("--pattern", type=str, default="unigram") argparser.add_argument("--use_layer_norm", type=str2bool, default=False) argparser.add_argument("--use_output_gate", type=str2bool, default=False) argparser.add_argument("--activation", type=str, default="none") diff --git a/requirements.txt b/requirements.txt index 3ae9c11..d9a2390 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ pyyaml==3.12 #ipython==6.2.0 tensorboardx==0.8 pynvrtc +termcolor diff --git a/rrnn.py b/rrnn.py index 322f83c..c2a189d 100644 --- a/rrnn.py +++ b/rrnn.py @@ -1,9 +1,162 @@ import sys import torch import torch.nn as nn +import numpy as np from torch.autograd import Variable -def RRNN_Compute_CPU(d, k, semiring, bidirectional=False): +class Max(): + zero = -10000000 + + def select(x, y): + return x >= y + +class Min(): + zero = 10000000 + + def select(x, y): + return x <= y + +def RRNN_Ngram_Compute_CPU(d, k, semiring, bidirectional=False): + + # Compute traces for n_patterns patterns and n_docs documents + class TraceElementParallel(): + def __init__(self, min_max, f, u, prev_traces, i, t, n_patterns, n_docs): + self.u_indices = np.zeros((n_docs, n_patterns, int(k / 2)), dtype=int) + + # Lower triangle in dynamic programming table is impossible (i.e., -inf) + if t < i: + self.score = min_max.zero*np.ones((n_docs, n_patterns)) + return + + # Previous trace values + u_score = np.copy(u.data.numpy()) + f_score = np.copy(f.data.numpy()) + + # If i > 0, including history in computation of u_score and u_indices + if i > 0: + prev_u_indices = prev_traces[i-1].u_indices + u_score *= prev_traces[i-1].score + else: + prev_u_indices = np.zeros((n_docs, n_patterns, int(k / 2)), dtype=int) + + # If t == i, we can't take a forget gate. + if t == i: + prev_f_indices = np.zeros((n_docs, n_patterns, int(k / 2)), dtype=int) + f_score = min_max.zero * np.ones((n_docs, n_patterns)) + # Otherwise, including history of forget gate. 
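The trace construction above is a Viterbi-style max-product dynamic program over a left-to-right pattern with self-loops: at each time step a state either consumes a main-path token (its input gate times the previous state's score from the previous step) or self-loops through its forget gate, and the better option wins. TraceElementParallel vectorizes this over the whole batch and all patterns at once; a single-document, single-pattern sketch of the same recurrence (my reconstruction, names are illustrative):

import numpy as np

def viterbi_trace(u, f, zero=-1e7):
    # u, f: (length, n_states) input / forget gate values for one pattern on one document
    length, n_states = u.shape
    score = np.full(n_states, zero)                       # best score of a path now in state i
    times = np.zeros((n_states, n_states), dtype=int)     # firing times along that best path
    for t in range(length):
        prev_score, prev_times = score.copy(), times.copy()
        for i in range(n_states):
            if t < i:
                continue                                  # state i is unreachable before time i
            take = u[t, i] * (prev_score[i - 1] if i > 0 else 1.0)
            stay = f[t, i] * prev_score[i] if t > i else zero
            if take >= stay:                              # Max.select; Min flips the comparison
                score[i] = take
                times[i] = prev_times[i - 1] if i > 0 else 0
                times[i, i] = t                           # state i fires at time t
            else:
                score[i] = stay
                times[i] = prev_times[i]                  # keep the earlier firing times
    return score[-1], times[-1, :]                        # full match ending at the last time step

rng = np.random.RandomState(0)
best_score, main_path = viterbi_trace(rng.rand(8, 3), rng.rand(8, 3))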
+ else: + prev_f_indices = prev_traces[i].u_indices + f_score *= prev_traces[i].score + + assert((not np.isnan(u_score).any()) and (not np.isnan(f_score).any())) + + # Dynamic program selection + selected = min_max.select(u_score, f_score) + not_selected = 1 - selected + + # Equivalent to np.maximum(u_score, f_score) (or minimum) + self.score = selected * u_score + not_selected * f_score + + # A fancy way of selecting the previous indices based on the selection criterion above. + prevs = np.expand_dims(selected, 2) * prev_u_indices + \ + np.expand_dims(not_selected, 2) * prev_f_indices + + # Updating u_indices with history (deep copy!) + self.u_indices[:, :, :i+1] = np.copy(prevs[:, :, :i+1]) + + # In the cases where u was selected, updating u_indices with current time step. + self.u_indices[selected, i] = t + + + def rrnn_compute_cpu(u, cs_init=None, eps=None, keep_trace=False, min_max=Max): + assert eps is None, "haven't implemented epsilon steps with arbitrary n-grams. Please set command line param to False." + bidir = 2 if bidirectional else 1 + assert u.size(-1) == k + length, batch = u.size(0), u.size(1) + + for i in range(len(cs_init)): + cs_init[i] = cs_init[i].contiguous().view(batch, bidir, d) + + us = [] + for i in range(0,int(k/2)): + us.append(u[..., i]) + forgets = [] + for i in range(int(k/2), k): + forgets.append(u[..., i]) + + cs_final = None + css = None + all_traces = None + + if not keep_trace: + cs_final = [[] for i in range(int(k/2))] + + css = [Variable(u.data.new(length, batch, bidir, d)) for i in range(int(k/2))] + + + for di in range(bidir): + if di == 0: + time_seq = range(length) + else: + time_seq = range(length - 1, -1, -1) + + if keep_trace: + prev_traces = [None for i in range(len(cs_init))] + else: + cs_prev = [cs_init[i][:, di, :] for i in range(len(cs_init))] + + # input + for t in time_seq: + # ind = 0 + if keep_trace: + # Traces of all pattern states in current time step + all_traces = [] + else: + cs_t = [] + + # States of pattern + for i in range(len(cs_init)): + if keep_trace: + all_traces.append( + TraceElementParallel(min_max, forgets[i][t, :, di, :], us[i][t, :, di, :], prev_traces, i, t, + us[i].size()[3], u.size()[1]) + ) + else: + first_term = cs_prev[i] * forgets[i][t, :, di, :] + second_term = us[i][t, :, di, :] + + if i > 0: + second_term = second_term * cs_prev[i-1] + + cs_t.append(first_term + second_term) + + + if keep_trace: + prev_traces = all_traces + else: + cs_prev = cs_t + + for i in range(len(cs_prev)): + css[i][t,:,di,:] = cs_t[i] + + if not keep_trace: + for i in range(len(cs_prev)): + cs_final[i].append(cs_t[i]) + + if not keep_trace: + for i in range(len(cs_final)): + cs_final[i] = torch.stack(cs_final[i], dim=1).view(batch, -1) + + return css, cs_final, all_traces + if semiring.type == 0: + # plus times + return rrnn_compute_cpu + else: + assert False, "OTHER SEMIRINGS NOT IMPLEMENTED!" + + + +def RRNN_Bigram_Compute_CPU(d, k, semiring, bidirectional=False): """CPU version of the core RRNN computation. 
Has the same interface as RRNN_Compute_GPU() but is a regular Python function @@ -55,7 +208,7 @@ def rrnn_semiring_compute_cpu(u, c1_init=None, c2_init=None, eps=None): return c1s, c2s, \ torch.stack(c1_final, dim=1).view(batch, -1), \ torch.stack(c2_final, dim=1).view(batch, -1) - + def rrnn_compute_cpu(u, c1_init=None, c2_init=None, eps=None): bidir = 2 if bidirectional else 1 assert u.size(-1) == k @@ -83,7 +236,10 @@ def rrnn_compute_cpu(u, c1_init=None, c2_init=None, eps=None): for t in time_seq: c1_t = c1_prev* forget1[t, :, di, :] + u1[t, :, di, :] - tmp = eps[di, :] + c1_prev + if eps is not None: + tmp = eps[di, :] + c1_prev + else: + tmp = c1_prev c2_t = c2_prev * forget2[t, :, di, :] + tmp * u2[t, :, di, :] c1_prev, c2_prev = c1_t, c2_t c1s[t,:,di,:], c2s[t,:,di,:] = c1_t, c2_t @@ -102,12 +258,50 @@ def rrnn_compute_cpu(u, c1_init=None, c2_init=None, eps=None): # otehrs return rrnn_semiring_compute_cpu +def RRNN_Unigram_Compute_CPU(d, k, semiring, bidirectional=False): + + def rrnn_compute_cpu(u, c_init=None): + bidir = 2 if bidirectional else 1 + assert u.size(-1) == k + length, batch = u.size(0), u.size(1) + if c_init is None: + assert False + else: + c_init = c_init.contiguous().view(batch, bidir, d) + + u, forget = u[..., 0], u[..., 1] + + c_final = [] + cs = Variable(u.data.new(length, batch, bidir, d)) + + for di in range(bidir): + if di == 0: + time_seq = range(length) + else: + time_seq = range(length - 1, -1, -1) + + c_prev = c_init[:, di, :] + for t in time_seq: + c_t = c_prev * forget[t, :, di, :] + u[t, :, di, :] + c_prev = c_t + cs[t, :, di, :] = c_t + c_final.append(c_t) + + return cs, torch.stack(c_final, dim=1).view(batch, -1) + + if semiring.type == 0: + # plus times + return rrnn_compute_cpu + else: + assert False + class RRNNCell(nn.Module): def __init__(self, semiring, n_in, n_out, + pattern="bigram", dropout=0.2, rnn_dropout=0.2, bidirectional=False, @@ -116,10 +310,15 @@ def __init__(self, use_selu=0, weight_norm=False, index=-1, - use_output_gate=True): + use_output_gate=True, + use_rho=False, + rho_sum_to_one=False, + use_last_cs=False, + use_epsilon_steps=True): super(RRNNCell, self).__init__() - assert (n_out % 2) == 0 + #assert (n_out % 2) == 0 self.semiring = semiring + self.pattern = pattern self.n_in = n_in self.n_out = n_out self.rnn_dropout = rnn_dropout @@ -130,6 +329,10 @@ def __init__(self, self.index = index self.activation_type = 0 self.use_output_gate = use_output_gate # borrowed from qrnn + self.use_rho = use_rho + self.rho_sum_to_one = rho_sum_to_one + self.use_last_cs = use_last_cs + self.use_epsilon_steps = use_epsilon_steps if use_tanh: self.activation_type = 1 elif use_relu: @@ -139,30 +342,47 @@ def __init__(self, # basic: in1, in2, f1, f2 # optional: output. 
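For an n-gram pattern the cell below allocates k = 2n channels per output dimension: the first n are transition inputs and the last n are forget gates, and each input is coupled to its forget gate via u_i * (1 - f_i), as real_bigram_forward and real_ngram_forward do further down. A shape-only sketch of that layout (toy sizes; the bidirectional axis is omitted):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

length, batch, n_out, ngram = 5, 2, 3, 4
k = 2 * ngram                                     # matches self.k = 2 * ngram
u_ = np.random.randn(length, batch, n_out, k)     # projected inputs, before gating
bias = np.zeros((k, n_out))

u = np.empty_like(u_)
for i in range(ngram, k):                         # last k/2 channels: forget gates
    u[..., i] = sigmoid(u_[..., i] + bias[i])
for i in range(ngram):                            # first k/2 channels: coupled inputs
    u[..., i] = u_[..., i] * (1.0 - u[..., i + ngram])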
- self.k = 5 if self.use_output_gate else 4 - self.n_bias = 5 if self.use_output_gate else 4 - self.size_per_dir = n_out*self.k + if self.pattern == "bigram": + self.k = 5 if self.use_output_gate else 4 + elif self.pattern == "unigram": + self.k = 3 if self.use_output_gate else 2 + else: + # it should be of the form "4-gram" + # should probably implement epsilon stuff, as in bigram + ngram = int(self.pattern.split("-")[0]) + self.k = 2 * (ngram) + + + if self.pattern != "unigram" and self.pattern != "1-gram": + if self.use_rho: + self.bias_final = nn.Parameter(torch.Tensor(self.bidir*n_out*int(self.k/2))) + if self.use_epsilon_steps: + self.bias_eps = nn.Parameter(torch.Tensor(self.bidir*n_out)) + + self.size_per_dir = n_out*self.k self.weight = nn.Parameter(torch.Tensor( n_in, self.size_per_dir*self.bidir )) self.bias = nn.Parameter(torch.Tensor( - n_out*self.n_bias*self.bidir + self.size_per_dir*self.bidir )) - self.bias_eps = nn.Parameter(torch.Tensor(self.bidir*n_out)) - self.bias_final = nn.Parameter(torch.Tensor(self.bidir*n_out*2)) self.init_weights() + def init_weights(self, rescale=True): val_range = (6.0 / (self.n_in + self.n_out)) ** 0.5 self.weight.data.uniform_(-val_range, val_range) # initialize bias self.bias.data.zero_() - self.bias_eps.data.zero_() - self.bias_final.data.zero_() - n_out = self.n_out + + if self.pattern != "unigram" and self.pattern != "1-gram": + if self.use_rho: + self.bias_final.data.zero_() + if self.use_epsilon_steps: + self.bias_eps.data.zero_() self.scale_x = 1 if not rescale: @@ -177,6 +397,7 @@ def init_weights(self, rescale=True): if self.weight_norm: self.init_weight_norm() + def init_weight_norm(self): weight_in = self.weight.data g = weight_in.norm(2, 0) @@ -275,7 +496,8 @@ def semiring_forward(self, input, init_hidden=None): gcs = self.calc_activation(cs).view(length, batch, bidir, n_out) return gcs.view(length, batch, -1), c1_final, c2_final - def real_forward(self, input, init_hidden=None): + + def real_bigram_forward(self, input, init_hidden=None): assert input.dim() == 2 or input.dim() == 3 n_in, n_out = self.n_in, self.n_out length, batch = input.size(0), input.size(-2) @@ -303,7 +525,7 @@ def real_forward(self, input, init_hidden=None): # basic: in1, in2, f1, f2 # optional: output. - bias = self.bias.view(self.n_bias, bidir, n_out) + bias = self.bias.view(self.k, bidir, n_out) _, _, forget_bias1, forget_bias2 = bias[:4, ...] if self.use_output_gate: @@ -317,18 +539,25 @@ def real_forward(self, input, init_hidden=None): u[..., 0] = u_[..., 0] * (1. - u[..., 2]) # input 1 u[..., 1] = u_[..., 1] * (1. 
- u[..., 3]) # input 2 - + if input.is_cuda: - from rrnn_gpu import RRNN_Compute_GPU - RRNN_Compute = RRNN_Compute_GPU(n_out, 4, self.semiring, self.bidirectional) + from rrnn_gpu import RRNN_Bigram_Compute_GPU + RRNN_Compute = RRNN_Bigram_Compute_GPU(n_out, 4, self.semiring, self.bidirectional) else: - RRNN_Compute = RRNN_Compute_CPU(n_out, 4, self.semiring, self.bidirectional) + RRNN_Compute = RRNN_Bigram_Compute_CPU(n_out, 4, self.semiring, self.bidirectional) + + if self.use_epsilon_steps: + eps = self.bias_eps.view(bidir, n_out).sigmoid() + else: + eps = None - eps = self.bias_eps.view(bidir, n_out).sigmoid() c1s, c2s, c1_final, c2_final= RRNN_Compute(u, c1_init, c2_init, eps) - rho = self.bias_final.view(bidir, n_out, 2).sigmoid() - cs = c1s * rho[...,0] + c2s * rho[...,1] + if self.use_rho: + rho = self.bias_final.view(bidir, n_out, 2).sigmoid() + cs = c1s * rho[...,0] + c2s * rho[...,1] + else: + cs = c1s + c2s if self.use_output_gate: gcs = self.calc_activation(output*cs) @@ -337,11 +566,183 @@ def real_forward(self, input, init_hidden=None): return gcs.view(length, batch, -1), c1_final, c2_final - def forward(self, input, init_hidden=None): + def real_unigram_forward(self, input, init_hidden=None): + assert input.dim() == 2 or input.dim() == 3 + n_in, n_out = self.n_in, self.n_out + length, batch = input.size(0), input.size(-2) + bidir = self.bidir + if init_hidden is None: + size = (batch, n_out * bidir) + c_init = Variable(input.data.new(*size).zero_()) + else: + c_init = init_hidden + + if self.training and (self.rnn_dropout>0): + mask = self.get_dropout_mask_((1, batch, n_in), self.rnn_dropout) + x = input * mask.expand_as(input) + else: + x = input + + x_2d = x if x.dim() == 2 else x.contiguous().view(-1, n_in) + + weight_in = self.weight if not self.weight_norm else self.apply_weight_norm() + u_ = x_2d.mm(weight_in) + u_ = u_.view(length, batch, bidir, n_out, self.k) + + + # basic: in, f + # optional: output. + bias = self.bias.view(self.k, bidir, n_out) + + _, forget_bias = bias[:2, ...] + if self.use_output_gate: + output_bias = bias[3, ...] + output = (u_[..., 3] + output_bias).sigmoid() + + u = Variable(u_.data.new(length, batch, bidir, n_out, 2)) + + u[..., 1] = (u_[..., 1] + forget_bias).sigmoid() # forget + u[..., 0] = u_[..., 0] * (1. - u[..., 1]) # input + + + if input.is_cuda: + from rrnn_gpu import RRNN_Compute_GPU + RRNN_Compute = RRNN_Unigram_Compute_GPU(n_out, 2, self.semiring, self.bidirectional) + else: + RRNN_Compute = RRNN_Unigram_Compute_CPU(n_out, 2, self.semiring, self.bidirectional) + + cs, c_final = RRNN_Compute(u, c_init) + + if self.use_output_gate: + gcs = self.calc_activation(output*cs) + else: + gcs = self.calc_activation(cs) + + return gcs.view(length, batch, -1), c_final + + def real_ngram_forward(self, input, init_hidden=None, keep_trace=False, min_max=Max): + assert input.dim() == 3 + n_in, n_out = self.n_in, self.n_out + length, batch = input.size(0), input.size(-2) + bidir = self.bidir + + if init_hidden is None: + size = (batch, n_out * bidir) + cs_init = [] + for i in range(int(self.k/2)): + cs_init.append(Variable(input.data.new(*size).zero_())) + + else: + assert False, "NOT IMPLEMENTED!" 
+ assert (len(init_hidden) == 2) + c1_init, c2_init, = init_hidden + + if self.training and (self.rnn_dropout>0): + mask = self.get_dropout_mask_((1, batch, n_in), self.rnn_dropout) + x = input * mask.expand_as(input) + else: + x = input + + x_2d = x if x.dim() == 2 else x.contiguous().view(-1, n_in) + + weight_in = self.weight if not self.weight_norm else self.apply_weight_norm() + u_ = x_2d.mm(weight_in) + u_ = u_.view(length, batch, bidir, n_out, self.k) + + # optional: output. + bias = self.bias.view(self.k, bidir, n_out) + + u = Variable(u_.data.new(length, batch, bidir, n_out, self.k)) + + for i in range(int(self.k/2),self.k): + forget_bias = bias[i, ...] + u[..., i] = (u_[..., i] + forget_bias).sigmoid() # forget + + for i in range(0, int(self.k/2)): + u[..., i] = u_[..., i] * (1. - u[..., i + int(self.k/2)]) # input + + if input.is_cuda: + + if self.k == 8: + + from rrnn_gpu import RRNN_4gram_Compute_GPU + RRNN_Compute_GPU = RRNN_4gram_Compute_GPU(n_out, self.k, self.semiring, self.bidirectional) + c1s, c2s, c3s, c4s, last_c1, last_c2, last_c3, last_c4 = RRNN_Compute_GPU(u, cs_init[0], cs_init[1], cs_init[2], cs_init[3]) + css = [c1s, c2s, c3s, c4s] + cs_final = [last_c1, last_c2, last_c3, last_c4] + + elif self.k == 6: + from rrnn_gpu import RRNN_3gram_Compute_GPU + RRNN_Compute_GPU = RRNN_3gram_Compute_GPU(n_out, self.k, self.semiring, self.bidirectional) + c1s, c2s, c3s, last_c1, last_c2, last_c3 = RRNN_Compute_GPU(u, cs_init[0], cs_init[1], cs_init[2]) + css = [c1s, c2s, c3s] + cs_final = [last_c1, last_c2, last_c3] + + elif self.k == 4: + from rrnn_gpu import RRNN_2gram_Compute_GPU + RRNN_Compute_GPU = RRNN_2gram_Compute_GPU(n_out, self.k, self.semiring, self.bidirectional) + c1s, c2s, last_c1, last_c2 = RRNN_Compute_GPU(u, cs_init[0], cs_init[1]) + css = [c1s, c2s] + cs_final = [last_c1, last_c2] + + elif self.k == 2: + from rrnn_gpu import RRNN_1gram_Compute_GPU + RRNN_Compute_GPU = RRNN_1gram_Compute_GPU(n_out, self.k, self.semiring, self.bidirectional) + c1s, last_c1 = RRNN_Compute_GPU(u, cs_init[0]) + css = [c1s] + cs_final = [last_c1] + + else: + assert False, "custom cuda kernel only implemented for 1,2,3,4-gram models" + else: + RRNN_Compute = RRNN_Ngram_Compute_CPU(n_out, self.k, self.semiring, self.bidirectional) + css, cs_final, traces = RRNN_Compute(u, cs_init, eps=None, keep_trace=keep_trace, min_max=min_max) + + if keep_trace: + return None, None, traces + + # instead of using \rho to weight the sum, we can give uniform weight. this might be + # more interpretable, as the \rhos might counteract the regularization terms + if self.use_rho: + if self.rho_sum_to_one: + sm = nn.Softmax(dim=2) + rho = sm(self.bias_final.view(bidir, n_out, int(self.k/2))) + else: + rho = self.bias_final.view(bidir, n_out, int(self.k/2)).sigmoid() + css_times_rho = [] + for i in range(len(css)): + css_times_rho.append(css[i] * rho[...,i]) + + cs = sum(css_times_rho) + else: + if self.use_last_cs: + cs = css[-1] + else: + cs = sum(css) + + if self.use_output_gate: + assert False, "THIS HASN'T BEEN IMPLEMENTED YET!" 
+ gcs = self.calc_activation(output*cs) + else: + gcs = self.calc_activation(cs) + + return gcs.view(length, batch, -1), cs_final, traces + + + def forward(self, input, init_hidden=None, keep_trace=False, min_max=Max): + if self.semiring.type == 0: # plus times - return self.real_forward(input=input, init_hidden=init_hidden) + if self.pattern == "bigram": + return self.real_bigram_forward(input=input, init_hidden=init_hidden) + elif self.pattern == "unigram": + return self.real_unigram_forward(input=input, init_hidden=init_hidden) + else: + # it should be of the form "4-gram" + return self.real_ngram_forward(input=input, init_hidden=init_hidden, keep_trace=keep_trace, min_max=min_max) + else: + assert False, "not implemented yet." return self.semiring_forward(input=input, init_hidden=init_hidden) def get_dropout_mask_(self, size, p, rescale=True): @@ -352,12 +753,93 @@ def get_dropout_mask_(self, size, p, rescale=True): return Variable(w.new(*size).bernoulli_(1-p)) +class RRNNLayer(nn.Module): + def __init__(self, + semiring, + n_in, + n_out, + pattern, + dropout=0.2, + rnn_dropout=0.2, + bidirectional=False, + use_tanh=1, + use_relu=0, + use_selu=0, + weight_norm=False, + index=-1, + use_output_gate=True, + use_rho=False, + rho_sum_to_one=False, + use_last_cs=False, + use_epsilon_steps=True): + super(RRNNLayer, self).__init__() + + self.cells = nn.ModuleList() + + assert len(pattern) == len(n_out) + num_cells = len(pattern) + for i in range(num_cells): + if n_out[i] > 0: + one_cell = RRNNCell( + semiring=semiring, + n_in=n_in, + n_out=n_out[i], + pattern=pattern[i], + dropout=dropout, + rnn_dropout=rnn_dropout, + bidirectional=bidirectional, + use_tanh=use_tanh, + use_relu=use_relu, + use_selu=use_selu, + weight_norm=weight_norm, + index=index, + use_output_gate=use_output_gate, + use_rho=use_rho, + rho_sum_to_one=rho_sum_to_one, + use_last_cs=use_last_cs, + use_epsilon_steps=use_epsilon_steps + ) + self.cells.append(one_cell) + + def init_weights(self): + for cell in self.cells: + cell.init_weights() + + def forward(self, input, init_hidden=None, keep_trace=False, min_max=Max): + #import pdb; pdb.set_trace() + + gcs, cs_final, traces = self.cells[0](input, init_hidden, keep_trace, min_max=min_max) + + if keep_trace: + # An array where each element is the traces for all the patterns of one pattern length. 
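RRNNLayer runs one cell per pattern length on the same input and concatenates their outputs along the feature dimension, so the width seen by the next layer (and by the classifier) is the sum of the per-length output sizes, which is also why deeper layers use n_in = sum(hidden_size). A toy shape check (hypothetical sizes):

import torch

length, batch = 7, 2
d_out = [24, 24, 12]                                      # one entry per surviving pattern length
outs = [torch.zeros(length, batch, d) for d in d_out]     # stand-ins for each cell's gcs

gcs = outs[0]
for cur in outs[1:]:
    gcs = torch.cat((gcs, cur), 2)                        # same concatenation as RRNNLayer.forward
print(gcs.shape)                                          # torch.Size([7, 2, 60]) == sum(d_out)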
+ all_traces = [] + + all_traces.append(traces) + + for i, cell in enumerate(self.cells): + if i == 0: + continue + else: + gcs_cur, _, traces = cell(input, init_hidden, keep_trace, min_max=min_max) + + if keep_trace: + all_traces.append(traces) + else: + gcs = torch.cat((gcs, gcs_cur), 2) + #for j in range(len(cs_final)): + # cs_final[j] = torch.cat((cs_final[j], cs_final_cur[j]), 1) + #cs_final = torch.cat(cs_final, cs_final_cur) + + return gcs, None, all_traces + + class RRNN(nn.Module): def __init__(self, semiring, input_size, hidden_size, - num_layers=2, + num_layers, + pattern="bigram", dropout=0.2, rnn_dropout=0.2, bidirectional=False, @@ -366,13 +848,18 @@ def __init__(self, use_selu=0, weight_norm=False, layer_norm=False, - use_output_gate=True): + use_output_gate=True, + use_rho=False, + rho_sum_to_one=False, + use_last_cs=False, + use_epsilon_steps=True): super(RRNN, self).__init__() assert not bidirectional self.semiring = semiring self.input_size = input_size - self.hidden_size = hidden_size + self.hidden_size = [int(one_size) for one_size in hidden_size.split(",")] self.num_layers = num_layers + self.pattern = [one_pattern for one_pattern in pattern.split(",")] self.dropout = dropout self.rnn_dropout = rnn_dropout self.rnn_lst = nn.ModuleList() @@ -380,18 +867,21 @@ def __init__(self, self.bidirectional = bidirectional self.use_layer_norm = layer_norm self.use_wieght_norm = weight_norm - self.out_size = hidden_size * 2 if bidirectional else hidden_size + #self.out_size = hidden_size * 2 if bidirectional else hidden_size + + assert len(self.hidden_size) == len(self.pattern), "each n-gram must have an output size." if use_tanh + use_relu + use_selu > 1: sys.stderr.write("\nWARNING: More than one activation enabled in RRNN" " (tanh: {} relu: {} selu: {})\n".format(use_tanh, use_relu, use_selu) ) - + for i in range(num_layers): - l = RRNNCell( + l = RRNNLayer( semiring=semiring, - n_in=self.input_size if i == 0 else self.out_size, + n_in=self.input_size if i == 0 else sum(self.hidden_size), n_out=self.hidden_size, + pattern=self.pattern, dropout=dropout if i+1 != num_layers else 0., rnn_dropout=rnn_dropout, bidirectional=bidirectional, @@ -400,7 +890,11 @@ def __init__(self, use_selu=use_selu, weight_norm=weight_norm, index=i+1, - use_output_gate=use_output_gate + use_output_gate=use_output_gate, + use_rho=use_rho, + rho_sum_to_one=rho_sum_to_one, + use_last_cs=use_last_cs, + use_epsilon_steps=use_epsilon_steps ) self.rnn_lst.append(l) if layer_norm: @@ -410,18 +904,41 @@ def init_weights(self): for l in self.rnn_lst: l.init_weights() - def forward(self, input, init_hidden=None, return_hidden=True): - assert input.dim() == 3 # (len, batch, n_in) + def unigram_forward(self, input, init_hidden=None, return_hidden=True): + assert input.dim() == 3 # (len, batch, n_in) + if init_hidden is None: + init_hidden = [None for _ in range(self.num_layers)] + else: + for c in init_hidden: + assert c.dim() == 2 + init_hidden = [c.squeeze(0) for c in + init_hidden.chunk(self.num_layers, 0)] + + prevx = input + lstc = [] + for i, rnn in enumerate(self.rnn_lst): + h, c = rnn(prevx, init_hidden[i]) + prevx = self.ln_lst[i](h) if self.use_layer_norm else h + lstc.append(c) + + if return_hidden: + return prevx, torch.stack(lstc) + else: + return prevx + + def bigram_forward(self, input, init_hidden=None, return_hidden=True): + + assert input.dim() == 3 # (len, batch, n_in) if init_hidden is None: init_hidden = [None for _ in range(self.num_layers)] else: for c in init_hidden: assert c.dim() == 3 
init_hidden = [(c1.squeeze(0), c2.squeeze(0)) - for c1,c2 in zip( - init_hidden[0].chunk(self.num_layers, 0), - init_hidden[1].chunk(self.num_layers, 0) - )] + for c1, c2 in zip( + init_hidden[0].chunk(self.num_layers, 0), + init_hidden[1].chunk(self.num_layers, 0) + )] prevx = input lstc1, lstc2 = [], [] @@ -437,6 +954,56 @@ def forward(self, input, init_hidden=None, return_hidden=True): return prevx + + def ngram_forward(self, input, init_hidden=None, return_hidden=True, keep_trace=False, min_max=Max): + assert input.dim() == 3 # (len, batch, n_in) + if init_hidden is None: + init_hidden = [None for _ in range(self.num_layers)] + else: + assert False, "THIS IS NOT IMPLEMENTED, I DON'T THINK IT'S NECESSARY FOR CLASSIFICATION" + for c in init_hidden: + assert c.dim() == int(self.k/2) + init_hidden = [(c1.squeeze(0), c2.squeeze(0)) + for c1, c2 in zip( + init_hidden[0].chunk(self.num_layers, 0), + init_hidden[1].chunk(self.num_layers, 0) + )] + + + prevx = input + # ngram used to be a parameter to this method. + #lstcs = [[] for i in range(ngram)] + + first_traces = None + for i, rnn in enumerate(self.rnn_lst): + h, cs, traces = rnn(prevx, init_hidden[i], keep_trace, min_max=min_max) + # Only visualize first layer + if i == 0 and keep_trace: + first_traces = traces + + #for j in range(len(cs)): + # lstcs[j].append(cs[j]) + prevx = self.ln_lst[i](h) if self.use_layer_norm else h + + #stacked_lstcs = [torch.stack(lstcs[i]) for i in range(len(lstcs))] + stacked_lstcs = None + + if return_hidden: + return prevx, stacked_lstcs, first_traces + else: + return prevx + + def forward(self, input, init_hidden=None, return_hidden=True, keep_trace=False, min_max=Max): + if self.pattern == "unigram": + return self.unigram_forward(input, init_hidden, return_hidden) + elif self.pattern == "bigram": + return self.bigram_forward(input, init_hidden, return_hidden) + else: + # it should be of the form "4-gram" + #ngram = int(self.pattern.split("-")[0]) + return self.ngram_forward(input, init_hidden, return_hidden, keep_trace=keep_trace, min_max=min_max) + + class LayerNorm(nn.Module): def __init__(self, features, eps=1e-6): diff --git a/rrnn_gpu.py b/rrnn_gpu.py index 15e253f..4775ce5 100644 --- a/rrnn_gpu.py +++ b/rrnn_gpu.py @@ -5,19 +5,145 @@ from cupy.cuda import function import numpy as np from cuda.utils import * -from cuda.rrnn import * -from cuda.rrnn_semiring import * +from cuda.bigram_rrnn import * +from cuda.bigram_rrnn_semiring import * +from cuda.unigram_rrnn import * +from cuda.fourgram_rrnn import * +from cuda.threegram_rrnn import * +from cuda.twogram_rrnn import * +from cuda.onegram_rrnn import * -class RRNN_Compute_GPU(Function): +class RRNN_Unigram_Compute_GPU(Function): - _RRNN_PROG = Program((UTIL+RRNN+RRNN_SEMIRING).encode("utf-8"), "rrnn_prog.cu".encode()) + #_RRNN_PROG = Program((UTIL + UNIGRAM_RRNN).encode("utf-8"), "rrnn_prog.cu".encode()) + #_RRNN_PTX = _RRNN_PROG.compile() + #_DEVICE2FUNC = {} + + + def __init__(self, d_out, k, semiring, bidirectional=False): + super(RRNN_Unigram_Compute_GPU, self).__init__() + self.semiring = semiring + self.d_out = d_out + self.k = k + self.bidirectional = bidirectional + assert not bidirectional + + + def compile_functions(self): + device = torch.cuda.current_device() + print ("RRNN loaded for gpu {}".format(device)) + mod = function.Module() + mod.load(bytes(self._RRNN_PTX.encode())) + + if self.semiring.type == 0: + fwd_func = mod.get_function("rrnn_fwd") + bwd_func = mod.get_function("rrnn_bwd") + Stream = namedtuple("Stream", ["ptr"]) + 
current_stream = Stream(ptr=torch.cuda.current_stream().cuda_stream) + self._DEVICE2FUNC[device] = ( + current_stream, fwd_func, bwd_func, + ) + return current_stream, fwd_func, bwd_func + else: + fwd_func = mod.get_function("rrnn_semiring_fwd") + bwd_func = mod.get_function("rrnn_semiring_bwd") + Stream = namedtuple("Stream", ["ptr"]) + current_stream = Stream(ptr=torch.cuda.current_stream().cuda_stream) + self._DEVICE2FUNC[device] = ( + current_stream, fwd_func, bwd_func + ) + return current_stream, fwd_func, bwd_func + + + def get_functions(self): + res = self._DEVICE2FUNC.get(torch.cuda.current_device(), None) + return res if res else self.compile_functions() + + + def forward(self, u, c_init=None): + bidir = 2 if self.bidirectional else 1 + assert u.size(-1) == self.k + length, batch = u.size(0), u.size(1) + dim = self.d_out + ncols = batch*dim*bidir + thread_per_block = min(512, ncols) + num_block = (ncols-1)//thread_per_block+1 + if c_init is None: + assert False + + size = (length, batch, bidir, dim) + cs = u.new(*size) + stream, fwd_func, _ = self.get_functions() + FUNC = fwd_func + FUNC(args=[ + u.contiguous().data_ptr(), + c_init.contiguous().data_ptr(), + np.int32(length), + np.int32(batch), + np.int32(dim), + np.int32(self.k), + cs.data_ptr(), + np.int32(self.semiring.type)], + block = (thread_per_block,1,1), grid = (num_block,1,1), + stream=stream + ) + self.save_for_backward(u, c_init) + self.intermediate_cs = cs + if self.bidirectional: + last_c = torch.cat((cs[-1,:,0,:], cs[0,:,1,:]), dim=1) + else: + last_c = cs[-1,...].view(batch, -1) + return cs, last_c + + + def backward(self, grad_cs, grad_last_c): + bidir = 2 if self.bidirectional else 1 + u, c_init = self.saved_tensors + cs = self.intermediate_cs + length, batch = u.size(0), u.size(1) + dim = self.d_out + ncols = batch*dim*bidir + thread_per_block = min(512, ncols) + num_block = (ncols-1)//thread_per_block+1 + + if c_init is None: + assert False + # init_ = x.new(ncols).zero_() if init is None else init + grad_u = u.new(*u.size()) + grad_init_c = u.new(batch, dim*bidir) + stream, _, bwd_func = self.get_functions() + FUNC = bwd_func + + FUNC(args=[ + u.contiguous().data_ptr(), + c_init.contiguous().data_ptr(), + cs.data_ptr(), + grad_cs.data_ptr(), + grad_last_c.contiguous().data_ptr(), + np.int32(length), + np.int32(batch), + np.int32(dim), + np.int32(self.k), + grad_u.data_ptr(), + grad_init_c.data_ptr(), + np.int32(self.semiring.type)], + block = (thread_per_block,1,1), grid = (num_block,1,1), + stream=stream + ) + + return grad_u, grad_init_c + + +class RRNN_Bigram_Compute_GPU(Function): + + _RRNN_PROG = Program((UTIL + BIGRAM_RRNN + BIGRAM_RRNN_SEMIRING).encode("utf-8"), "rrnn_prog.cu".encode()) _RRNN_PTX = _RRNN_PROG.compile() _DEVICE2FUNC = {} def __init__(self, d_out, k, semiring, bidirectional=False): - super(RRNN_Compute_GPU, self).__init__() + super(RRNN_Bigram_Compute_GPU, self).__init__() self.semiring = semiring self.d_out = d_out self.k = k @@ -143,4 +269,519 @@ def backward(self, grad_c1s, grad_c2s, grad_last_c1, grad_last_c2): stream=stream ) - return grad_u, grad_init_c1, grad_init_c2, grad_eps \ No newline at end of file + return grad_u, grad_init_c1, grad_init_c2, grad_eps + + +class RRNN_1gram_Compute_GPU(Function): + + _RRNN_PROG = Program((UTIL + ONEGRAM_RRNN).encode("utf-8"), "rrnn_prog.cu".encode()) + _RRNN_PTX = _RRNN_PROG.compile() + _DEVICE2FUNC = {} + + + def __init__(self, d_out, k, semiring, bidirectional=False): + super(RRNN_1gram_Compute_GPU, self).__init__() + self.semiring = 
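All of these GPU wrappers use the same launch shape: one thread per (batch, dim) column, capped at 512 threads per block, with enough blocks to cover the rest. A small sketch of that arithmetic (the numbers are illustrative):

def launch_shape(batch, dim, bidir=1, max_threads=512):
    ncols = batch * dim * bidir                       # one thread per output column
    thread_per_block = min(max_threads, ncols)
    num_block = (ncols - 1) // thread_per_block + 1   # ceil(ncols / thread_per_block)
    return thread_per_block, num_block

print(launch_shape(batch=32, dim=256))   # (512, 16): 8192 columns covered by 16 blocks of 512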
semiring
+        self.d_out = d_out
+        self.k = k
+        self.bidirectional = bidirectional
+        assert not bidirectional
+
+
+    def compile_functions(self):
+        device = torch.cuda.current_device()
+        print ("RRNN loaded for gpu {}".format(device))
+        mod = function.Module()
+        mod.load(bytes(self._RRNN_PTX.encode()))
+
+        if self.semiring.type == 0:
+            fwd_func = mod.get_function("rrnn_fwd")
+            bwd_func = mod.get_function("rrnn_bwd")
+            Stream = namedtuple("Stream", ["ptr"])
+            current_stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+            self._DEVICE2FUNC[device] = (
+                current_stream, fwd_func, bwd_func,
+            )
+            return current_stream, fwd_func, bwd_func
+        else:
+            assert False, "other semirings are not currently implemented."
+
+    def get_functions(self):
+        res = self._DEVICE2FUNC.get(torch.cuda.current_device(), None)
+        return res if res else self.compile_functions()
+
+
+    def forward(self, u, c1_init=None):
+        bidir = 2 if self.bidirectional else 1
+        assert u.size(-1) == self.k
+        length, batch = u.size(0), u.size(1)
+        dim = self.d_out
+        ncols = batch*dim*bidir
+        thread_per_block = min(512, ncols)
+        num_block = (ncols-1)//thread_per_block+1
+        if c1_init is None:
+            assert False
+
+        size = (length, batch, bidir, dim)
+        c1s = u.new(*size)
+        stream, fwd_func, _ = self.get_functions()
+        FUNC = fwd_func
+        FUNC(args=[
+            u.contiguous().data_ptr(),
+            c1_init.contiguous().data_ptr(),
+            np.int32(length),
+            np.int32(batch),
+            np.int32(dim),
+            np.int32(self.k),
+            c1s.data_ptr(),
+            np.int32(self.semiring.type)],
+            block = (thread_per_block,1,1), grid = (num_block,1,1),
+            stream=stream
+        )
+        self.save_for_backward(u, c1_init)
+        self.intermediate_c1s = c1s
+
+        if self.bidirectional:
+            assert False, "bidirectionality isn't implemented yet"
+        else:
+            last_c1 = c1s[-1,...].view(batch, -1)
+        return c1s, last_c1
+
+
+    def backward(self, grad_c1s, grad_last_c1):
+        bidir = 2 if self.bidirectional else 1
+        u, c1_init = self.saved_tensors
+        c1s = self.intermediate_c1s
+        length, batch = u.size(0), u.size(1)
+        dim = self.d_out
+        ncols = batch*dim*bidir
+        thread_per_block = min(512, ncols)
+        num_block = (ncols-1)//thread_per_block+1
+
+        if c1_init is None:
+            assert False
+            # init_ = x.new(ncols).zero_() if init is None else init
+        grad_u = u.new(*u.size())
+        grad_init_c1 = u.new(batch, dim*bidir)
+        stream, _, bwd_func = self.get_functions()
+        FUNC = bwd_func
+
+        FUNC(args=[
+            u.contiguous().data_ptr(),
+            c1_init.contiguous().data_ptr(),
+            c1s.data_ptr(),
+            grad_c1s.data_ptr(),
+            grad_last_c1.contiguous().data_ptr(),
+            np.int32(length),
+            np.int32(batch),
+            np.int32(dim),
+            np.int32(self.k),
+            grad_u.data_ptr(),
+            grad_init_c1.data_ptr(),
+            np.int32(self.semiring.type)],
+            block = (thread_per_block,1,1), grid = (num_block,1,1),
+            stream=stream
+        )
+
+        return grad_u, grad_init_c1
+
+
+class RRNN_2gram_Compute_GPU(Function):
+
+    _RRNN_PROG = Program((UTIL + TWOGRAM_RRNN).encode("utf-8"), "rrnn_prog.cu".encode())
+    _RRNN_PTX = _RRNN_PROG.compile()
+    _DEVICE2FUNC = {}
+
+
+    def __init__(self, d_out, k, semiring, bidirectional=False):
+        super(RRNN_2gram_Compute_GPU, self).__init__()
+        self.semiring = semiring
+        self.d_out = d_out
+        self.k = k
+        self.bidirectional = bidirectional
+        assert not bidirectional
+
+
+    def compile_functions(self):
+        device = torch.cuda.current_device()
+        print ("RRNN loaded for gpu {}".format(device))
+        mod = function.Module()
+        mod.load(bytes(self._RRNN_PTX.encode()))
+
+        if self.semiring.type == 0:
+            fwd_func = mod.get_function("rrnn_fwd")
+            bwd_func = mod.get_function("rrnn_bwd")
+            Stream = namedtuple("Stream", ["ptr"])
+            current_stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+            self._DEVICE2FUNC[device] = (
+                current_stream, fwd_func, bwd_func,
+            )
+            return current_stream, fwd_func, bwd_func
+        else:
+            assert False, "other semirings are not currently implemented."
+
+    def get_functions(self):
+        res = self._DEVICE2FUNC.get(torch.cuda.current_device(), None)
+        return res if res else self.compile_functions()
+
+
+    def forward(self, u, c1_init=None, c2_init=None):
+        bidir = 2 if self.bidirectional else 1
+        assert u.size(-1) == self.k
+        length, batch = u.size(0), u.size(1)
+        dim = self.d_out
+        ncols = batch*dim*bidir
+        thread_per_block = min(512, ncols)
+        num_block = (ncols-1)//thread_per_block+1
+        if c1_init is None:
+            assert False
+
+        size = (length, batch, bidir, dim)
+        c1s = u.new(*size)
+        c2s = u.new(*size)
+        stream, fwd_func, _ = self.get_functions()
+        FUNC = fwd_func
+        FUNC(args=[
+            u.contiguous().data_ptr(),
+            c1_init.contiguous().data_ptr(),
+            c2_init.contiguous().data_ptr(),
+            np.int32(length),
+            np.int32(batch),
+            np.int32(dim),
+            np.int32(self.k),
+            c1s.data_ptr(),
+            c2s.data_ptr(),
+            np.int32(self.semiring.type)],
+            block = (thread_per_block,1,1), grid = (num_block,1,1),
+            stream=stream
+        )
+        self.save_for_backward(u, c1_init, c2_init)
+        self.intermediate_c1s, self.intermediate_c2s = c1s, c2s
+
+        if self.bidirectional:
+            assert False, "bidirectionality isn't implemented yet"
+        else:
+            last_c1 = c1s[-1,...].view(batch, -1)
+            last_c2 = c2s[-1,...].view(batch, -1)
+        return c1s, c2s, last_c1, last_c2
+
+
+    def backward(self, grad_c1s, grad_c2s, grad_last_c1, grad_last_c2):
+        bidir = 2 if self.bidirectional else 1
+        u, c1_init, c2_init = self.saved_tensors
+        c1s, c2s = self.intermediate_c1s, self.intermediate_c2s
+        length, batch = u.size(0), u.size(1)
+        dim = self.d_out
+        ncols = batch*dim*bidir
+        thread_per_block = min(512, ncols)
+        num_block = (ncols-1)//thread_per_block+1
+
+        if c1_init is None:
+            assert False
+            # init_ = x.new(ncols).zero_() if init is None else init
+        grad_u = u.new(*u.size())
+        grad_init_c1 = u.new(batch, dim*bidir)
+        grad_init_c2 = u.new(batch, dim*bidir)
+        stream, _, bwd_func = self.get_functions()
+        FUNC = bwd_func
+
+        FUNC(args=[
+            u.contiguous().data_ptr(),
+            c1_init.contiguous().data_ptr(),
+            c2_init.contiguous().data_ptr(),
+            c1s.data_ptr(),
+            c2s.data_ptr(),
+            grad_c1s.data_ptr(),
+            grad_c2s.data_ptr(),
+            grad_last_c1.contiguous().data_ptr(),
+            grad_last_c2.contiguous().data_ptr(),
+            np.int32(length),
+            np.int32(batch),
+            np.int32(dim),
+            np.int32(self.k),
+            grad_u.data_ptr(),
+            grad_init_c1.data_ptr(),
+            grad_init_c2.data_ptr(),
+            np.int32(self.semiring.type)],
+            block = (thread_per_block,1,1), grid = (num_block,1,1),
+            stream=stream
+        )
+
+        return grad_u, grad_init_c1, grad_init_c2
+
+class RRNN_3gram_Compute_GPU(Function):
+
+    _RRNN_PROG = Program((UTIL + THREEGRAM_RRNN).encode("utf-8"), "rrnn_prog.cu".encode())
+    _RRNN_PTX = _RRNN_PROG.compile()
+    _DEVICE2FUNC = {}
+
+
+    def __init__(self, d_out, k, semiring, bidirectional=False):
+        super(RRNN_3gram_Compute_GPU, self).__init__()
+        self.semiring = semiring
+        self.d_out = d_out
+        self.k = k
+        self.bidirectional = bidirectional
+        assert not bidirectional
+
+
+    def compile_functions(self):
+        device = torch.cuda.current_device()
+        print ("RRNN loaded for gpu {}".format(device))
+        mod = function.Module()
+        mod.load(bytes(self._RRNN_PTX.encode()))
+
+        if self.semiring.type == 0:
+            fwd_func = mod.get_function("rrnn_fwd")
+            bwd_func = mod.get_function("rrnn_bwd")
+            Stream = namedtuple("Stream", ["ptr"])
+            current_stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+            self._DEVICE2FUNC[device] = (
+                current_stream, fwd_func, bwd_func,
+            )
+            return current_stream, fwd_func, bwd_func
+        else:
+            assert False, "other semirings are not currently implemented."
+
+    def get_functions(self):
+        res = self._DEVICE2FUNC.get(torch.cuda.current_device(), None)
+        return res if res else self.compile_functions()
+
+
+    def forward(self, u, c1_init=None, c2_init=None, c3_init=None):
+        bidir = 2 if self.bidirectional else 1
+        assert u.size(-1) == self.k
+        length, batch = u.size(0), u.size(1)
+        dim = self.d_out
+        ncols = batch*dim*bidir
+        thread_per_block = min(512, ncols)
+        num_block = (ncols-1)//thread_per_block+1
+        if c1_init is None:
+            assert False
+
+        size = (length, batch, bidir, dim)
+        c1s = u.new(*size)
+        c2s = u.new(*size)
+        c3s = u.new(*size)
+        stream, fwd_func, _ = self.get_functions()
+        FUNC = fwd_func
+        FUNC(args=[
+            u.contiguous().data_ptr(),
+            c1_init.contiguous().data_ptr(),
+            c2_init.contiguous().data_ptr(),
+            c3_init.contiguous().data_ptr(),
+            np.int32(length),
+            np.int32(batch),
+            np.int32(dim),
+            np.int32(self.k),
+            c1s.data_ptr(),
+            c2s.data_ptr(),
+            c3s.data_ptr(),
+            np.int32(self.semiring.type)],
+            block = (thread_per_block,1,1), grid = (num_block,1,1),
+            stream=stream
+        )
+        self.save_for_backward(u, c1_init, c2_init, c3_init)
+        self.intermediate_c1s, self.intermediate_c2s, self.intermediate_c3s = c1s, c2s, c3s
+
+        if self.bidirectional:
+            assert False, "bidirectionality isn't implemented yet"
+        else:
+            last_c1 = c1s[-1,...].view(batch, -1)
+            last_c2 = c2s[-1,...].view(batch, -1)
+            last_c3 = c3s[-1,...].view(batch, -1)
+        return c1s, c2s, c3s, last_c1, last_c2, last_c3
+
+
+    def backward(self, grad_c1s, grad_c2s, grad_c3s, grad_last_c1, grad_last_c2, grad_last_c3):
+        bidir = 2 if self.bidirectional else 1
+        u, c1_init, c2_init, c3_init = self.saved_tensors
+        c1s, c2s, c3s = self.intermediate_c1s, self.intermediate_c2s, self.intermediate_c3s
+        length, batch = u.size(0), u.size(1)
+        dim = self.d_out
+        ncols = batch*dim*bidir
+        thread_per_block = min(512, ncols)
+        num_block = (ncols-1)//thread_per_block+1
+
+        if c1_init is None:
+            assert False
+            # init_ = x.new(ncols).zero_() if init is None else init
+        grad_u = u.new(*u.size())
+        grad_init_c1 = u.new(batch, dim*bidir)
+        grad_init_c2 = u.new(batch, dim*bidir)
+        grad_init_c3 = u.new(batch, dim*bidir)
+        stream, _, bwd_func = self.get_functions()
+        FUNC = bwd_func
+
+        FUNC(args=[
+            u.contiguous().data_ptr(),
+            c1_init.contiguous().data_ptr(),
+            c2_init.contiguous().data_ptr(),
+            c3_init.contiguous().data_ptr(),
+            c1s.data_ptr(),
+            c2s.data_ptr(),
+            c3s.data_ptr(),
+            grad_c1s.data_ptr(),
+            grad_c2s.data_ptr(),
+            grad_c3s.data_ptr(),
+            grad_last_c1.contiguous().data_ptr(),
+            grad_last_c2.contiguous().data_ptr(),
+            grad_last_c3.contiguous().data_ptr(),
+            np.int32(length),
+            np.int32(batch),
+            np.int32(dim),
+            np.int32(self.k),
+            grad_u.data_ptr(),
+            grad_init_c1.data_ptr(),
+            grad_init_c2.data_ptr(),
+            grad_init_c3.data_ptr(),
+            np.int32(self.semiring.type)],
+            block = (thread_per_block,1,1), grid = (num_block,1,1),
+            stream=stream
+        )
+
+        return grad_u, grad_init_c1, grad_init_c2, grad_init_c3
+
+
+class RRNN_4gram_Compute_GPU(Function):
+
+    _RRNN_PROG = Program((UTIL + FOURGRAM_RRNN).encode("utf-8"), "rrnn_prog.cu".encode())
+    _RRNN_PTX = _RRNN_PROG.compile()
+    _DEVICE2FUNC = {}
+
+
+    def __init__(self, d_out, k, semiring, bidirectional=False):
+        super(RRNN_4gram_Compute_GPU, self).__init__()
+        self.semiring = semiring
+        self.d_out = d_out
+        self.k = k
+        self.bidirectional = bidirectional
+        assert not bidirectional
+
+
+    def compile_functions(self):
+        device = torch.cuda.current_device()
+        print ("RRNN loaded for gpu {}".format(device))
+        mod = function.Module()
+        mod.load(bytes(self._RRNN_PTX.encode()))
+
+        if self.semiring.type == 0:
+            fwd_func = mod.get_function("rrnn_fwd")
+            bwd_func = mod.get_function("rrnn_bwd")
+            Stream = namedtuple("Stream", ["ptr"])
+            current_stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+            self._DEVICE2FUNC[device] = (
+                current_stream, fwd_func, bwd_func,
+            )
+            return current_stream, fwd_func, bwd_func
+        else:
+            assert False, "other semirings are not currently implemented."
+
+    def get_functions(self):
+        res = self._DEVICE2FUNC.get(torch.cuda.current_device(), None)
+        return res if res else self.compile_functions()
+
+
+    def forward(self, u, c1_init=None, c2_init=None, c3_init=None, c4_init=None):
+        bidir = 2 if self.bidirectional else 1
+        assert u.size(-1) == self.k
+        length, batch = u.size(0), u.size(1)
+        dim = self.d_out
+        ncols = batch*dim*bidir
+        thread_per_block = min(512, ncols)
+        num_block = (ncols-1)//thread_per_block+1
+        if c1_init is None:
+            assert False
+
+        size = (length, batch, bidir, dim)
+        c1s = u.new(*size)
+        c2s = u.new(*size)
+        c3s = u.new(*size)
+        c4s = u.new(*size)
+        stream, fwd_func, _ = self.get_functions()
+        FUNC = fwd_func
+        FUNC(args=[
+            u.contiguous().data_ptr(),
+            c1_init.contiguous().data_ptr(),
+            c2_init.contiguous().data_ptr(),
+            c3_init.contiguous().data_ptr(),
+            c4_init.contiguous().data_ptr(),
+            np.int32(length),
+            np.int32(batch),
+            np.int32(dim),
+            np.int32(self.k),
+            c1s.data_ptr(),
+            c2s.data_ptr(),
+            c3s.data_ptr(),
+            c4s.data_ptr(),
+            np.int32(self.semiring.type)],
+            block = (thread_per_block,1,1), grid = (num_block,1,1),
+            stream=stream
+        )
+        self.save_for_backward(u, c1_init, c2_init, c3_init, c4_init)
+        self.intermediate_c1s, self.intermediate_c2s = c1s, c2s
+        self.intermediate_c3s, self.intermediate_c4s = c3s, c4s
+        if self.bidirectional:
+            assert False, "bidirectionality isn't implemented yet"
+        else:
+            last_c1 = c1s[-1,...].view(batch, -1)
+            last_c2 = c2s[-1,...].view(batch, -1)
+            last_c3 = c3s[-1,...].view(batch, -1)
+            last_c4 = c4s[-1,...].view(batch, -1)
+        return c1s, c2s, c3s, c4s, last_c1, last_c2, last_c3, last_c4
+
+
+    def backward(self, grad_c1s, grad_c2s, grad_c3s, grad_c4s, grad_last_c1, grad_last_c2, grad_last_c3, grad_last_c4):
+        bidir = 2 if self.bidirectional else 1
+        u, c1_init, c2_init, c3_init, c4_init = self.saved_tensors
+        c1s, c2s, c3s, c4s = self.intermediate_c1s, self.intermediate_c2s, self.intermediate_c3s, self.intermediate_c4s
+        length, batch = u.size(0), u.size(1)
+        dim = self.d_out
+        ncols = batch*dim*bidir
+        thread_per_block = min(512, ncols)
+        num_block = (ncols-1)//thread_per_block+1
+
+        if c1_init is None:
+            assert False
+            # init_ = x.new(ncols).zero_() if init is None else init
+        grad_u = u.new(*u.size())
+        grad_init_c1 = u.new(batch, dim*bidir)
+        grad_init_c2 = u.new(batch, dim*bidir)
+        grad_init_c3 = u.new(batch, dim*bidir)
+        grad_init_c4 = u.new(batch, dim*bidir)
+        stream, _, bwd_func = self.get_functions()
+        FUNC = bwd_func
+
+        FUNC(args=[
+            u.contiguous().data_ptr(),
+            c1_init.contiguous().data_ptr(),
+            c2_init.contiguous().data_ptr(),
+            c3_init.contiguous().data_ptr(),
+            c4_init.contiguous().data_ptr(),
+            c1s.data_ptr(),
+            c2s.data_ptr(),
+            c3s.data_ptr(),
+            c4s.data_ptr(),
+            grad_c1s.data_ptr(),
+            grad_c2s.data_ptr(),
+            grad_c3s.data_ptr(),
+            grad_c4s.data_ptr(),
+            grad_last_c1.contiguous().data_ptr(),
+            grad_last_c2.contiguous().data_ptr(),
+            grad_last_c3.contiguous().data_ptr(),
+            grad_last_c4.contiguous().data_ptr(),
+            np.int32(length),
+            np.int32(batch),
+            np.int32(dim),
+            np.int32(self.k),
+            grad_u.data_ptr(),
+            grad_init_c1.data_ptr(),
+            grad_init_c2.data_ptr(),
+            grad_init_c3.data_ptr(),
+            grad_init_c4.data_ptr(),
+            np.int32(self.semiring.type)],
+            block = (thread_per_block,1,1), grid = (num_block,1,1),
+            stream=stream
+        )
+
+        return grad_u, grad_init_c1, grad_init_c2, grad_init_c3, grad_init_c4
diff --git a/to_install.txt b/to_install.txt
new file mode 100644
index 0000000..e0ef78f
--- /dev/null
+++ b/to_install.txt
@@ -0,0 +1,3 @@
+conda create --name rational-recurrences python=3.6 pip
+source activate rational-recurrences
+pip install -r requirements.txt
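Post-install sanity check (a minimal sketch, not part of the diff above): the RRNN_*gram_Compute_GPU classes compile their CUDA kernels through NVRTC the first time get_functions() misses the per-device cache, so the environment created by to_install.txt needs a CUDA-enabled PyTorch build plus the Python bindings behind Program and function.Module. Those imports and requirements.txt are not shown in this diff, so the package names pynvrtc and cupy below are assumptions; adjust them to whatever requirements.txt actually pins.

    # check_env.py -- hypothetical helper, assuming pynvrtc/cupy provide Program and function.Module
    import importlib.util

    import torch

    # A CUDA device must be visible, or compile_functions() above cannot run.
    print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())

    # Assumed NVRTC/CUDA runtime bindings for kernel compilation and launch.
    for pkg in ("pynvrtc", "cupy"):
        found = importlib.util.find_spec(pkg) is not None
        print(pkg, "ok" if found else "missing")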