diff --git a/main.py b/main.py
index 8626170..98fd50d 100644
--- a/main.py
+++ b/main.py
@@ -64,6 +64,9 @@
                     help='optimizer to use (sgd, adam)')
 parser.add_argument('--when', nargs="+", type=int, default=[-1],
                     help='When (which epochs) to divide the learning rate by 10 - accepts multiple')
+parser.add_argument('--hebbian_softmax', action='store_true',
+                    help='Use Hebbian Softmax training for word-level LMs')
+
 args = parser.parse_args()
 args.tied = True
 
@@ -207,6 +210,19 @@ def train():
         if args.clip: torch.nn.utils.clip_grad_norm_(params, args.clip)
         optimizer.step()
 
+        if args.hebbian_softmax:
+            with torch.no_grad():
+                unique_tokens, tokens_counts = np.unique(targets.cpu().numpy(), return_counts=True)
+                for (token, token_count) in zip(unique_tokens, tokens_counts):
+                    token = int(token)
+                    if seen_counter[token] <= smoothing_limit:
+                        seen_counter[token] += token_count
+                        lamb = np.maximum(1 / seen_counter[token], min_discount)
+                        idx = (targets == token).nonzero().squeeze(1)
+                        h_bar = torch.index_select(output, 0, idx).mean(0)
+                        model.decoder.weight[token] = model.decoder.weight[token].data * (1 - lamb) \
+                            + h_bar * lamb
+
         total_loss += raw_loss.data
         optimizer.param_groups[0]['lr'] = lr2
         if batch % args.log_interval == 0 and batch > 0:
@@ -222,6 +238,17 @@
         batch += 1
         i += seq_len
 
+# Hebbian Softmax hyperparameters
+if args.hebbian_softmax:
+    if ntokens < 1000:
+        print('ntokens < 1000 suggests a char-level LM; disabling Hebbian Softmax, which is intended for word-level LMs')
+        args.hebbian_softmax = False
+    else:
+        seen_counter = np.zeros(ntokens)  # c
+        _, tokens_counts = np.unique(train_data.cpu().numpy(), return_counts=True)
+        min_discount = 1 / min(tokens_counts)  # gamma
+        smoothing_limit = min(tokens_counts) * args.epochs  # T
+
 # Loop over epochs.
 lr = args.lr
 best_val_loss = []
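
For reference, the weight update in the train() hunk is the Hebbian Softmax interpolation theta <- (1 - lambda) * theta + lambda * h_bar with lambda = max(1/c, gamma), applied only while the class count c is below the smoothing limit T. The snippet below is a minimal, self-contained sketch of that rule, not part of the patch; the function name hebbian_update and the toy values are illustrative and do not appear in main.py.

import numpy as np

def hebbian_update(theta, h_bar, count, min_discount, smoothing_limit):
    """Move a decoder row theta towards the mean hidden activation h_bar.

    theta           -- current class vector, shape (nhid,)
    h_bar           -- mean hidden state over the class's occurrences in the batch
    count           -- how many times the class has been seen so far (c)
    min_discount    -- floor on the interpolation weight (gamma)
    smoothing_limit -- once count exceeds this (T), rely on plain SGD only
    """
    if count > smoothing_limit:
        return theta                         # no Hebbian update past the limit
    lamb = max(1.0 / count, min_discount)    # discount anneals from 1 towards gamma
    return (1.0 - lamb) * theta + lamb * h_bar

# Toy usage: a class seen for the first time jumps straight to h_bar,
# since lambda = max(1/1, 0.1) = 1.
theta = np.zeros(4)
h_bar = np.array([1.0, 0.5, -0.5, 0.25])
print(hebbian_update(theta, h_bar, count=1, min_discount=0.1, smoothing_limit=100))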