Add beam search

This commit is contained in:
Rémi Louf
2019-10-31 17:59:16 +01:00
committed by Julien Chaumond
parent 1c71ecc880
commit 9660ba1cbd
6 changed files with 594 additions and 784 deletions

View File

@@ -25,9 +25,8 @@ class CNNDailyMailDataset(Dataset):
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init__(self, tokenizer, prefix="train", data_dir=""):
def __init__(self, data_dir="", prefix="train"):
assert os.path.isdir(data_dir)
self.tokenizer = tokenizer
# We initialize the class by listing all the files that contain
# stories and summaries. Files are not read in memory given
@@ -104,31 +103,30 @@ def _add_missing_period(line):
# --------------------------
def fit_to_block_size(sequence, block_size, pad_token):
def fit_to_block_size(sequence, block_size, pad_token_id):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter than the block size we pad it with -1 ids
which correspond to padding tokens.
If the sequence is shorter we append padding token to the right of the sequence.
"""
if len(sequence) > block_size:
return sequence[:block_size]
else:
sequence.extend([pad_token] * (block_size - len(sequence)))
sequence.extend([pad_token_id] * (block_size - len(sequence)))
return sequence
def build_lm_labels(sequence, pad_token):
""" Padding token, encoded as 0, are represented by the value -1 so they
def build_lm_labels(sequence, pad_token_id):
""" Padding token are replaced by the value -1 so they
are not taken into account in the loss computation. """
padded = sequence.clone()
padded[padded == pad_token] = -1
padded[padded == pad_token_id] = -1
return padded
def build_mask(sequence, pad_token):
def build_mask(sequence, pad_token_id):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask = torch.ones_like(sequence)
idx_pad_tokens = sequence == pad_token
idx_pad_tokens = sequence == pad_token_id
mask[idx_pad_tokens] = 0
return mask