delegate the padding with special tokens to the tokenizer
This commit is contained in:
@@ -53,20 +53,14 @@ def set_seed(args):
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
""" Abstracts a dataset used to train seq2seq models.
|
||||
|
||||
A seq2seq dataset consists of two files:
|
||||
- The source file that contains the source sequences, one line per sequence;
|
||||
- The target file contains the target sequences, one line per sequence.
|
||||
|
||||
The matching between source and target sequences is made on the basis of line numbers.
|
||||
""" Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. They consist of stories stored
|
||||
in different files where the summary sentences are indicated by the special `@highlight` token.
|
||||
To process the data, untar both datasets in the same folder, and path the path to this
|
||||
folder as the "train_data_file" argument. The formatting code was inspired by [2].
|
||||
To process the data, untar both datasets in the same folder, and pass the path to this
|
||||
folder as the "data_dir" argument. The formatting code was inspired by [2].
|
||||
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
@@ -82,9 +76,8 @@ class TextDataset(Dataset):
|
||||
self.examples = pickle.load(source)
|
||||
return
|
||||
|
||||
logger.info("Creating features from dataset at %s", directory)
|
||||
logger.info("Creating features from dataset at %s", data_dir)
|
||||
|
||||
# we need to iterate over both the cnn and the dailymail dataset
|
||||
datasets = ['cnn', 'dailymail']
|
||||
for dataset in datasets:
|
||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||
@@ -102,9 +95,10 @@ class TextDataset(Dataset):
|
||||
except IndexError:
|
||||
continue
|
||||
|
||||
src_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||
tgt_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||
example = _truncate_and_concatenate(src_sequence, tgt_sequence, blocksize)
|
||||
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||
story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize)
|
||||
example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
||||
self.examples.append(example)
|
||||
|
||||
logger.info("Saving features into cache file %s", cached_features_file)
|
||||
@@ -158,15 +152,13 @@ def _add_missing_period(line):
|
||||
return line + " ."
|
||||
|
||||
|
||||
def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
|
||||
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
||||
""" Concatenate the sequences and adapt their lengths to the block size.
|
||||
|
||||
Following [1] we perform the following transformations:
|
||||
- Add an [CLS] token at the beginning of the source sequence;
|
||||
- Add an [EOS] token at the end of the source and target sequences;
|
||||
- Concatenate the source and target + tokens sequence. If the concatenated sequence is
|
||||
longer than 512 we follow the 75%/25% rule in [1]: limit the source sequence's length to 384
|
||||
and the target sequence's length to 128.
|
||||
Following [1] we truncate the source and target + tokens sequences so they fit
|
||||
in the block size. If the concatenated sequence is longer than 512 we follow
|
||||
the 75%/25% rule in [1]: limit the source sequence's length to 384 and the
|
||||
target sequence's length to 128.
|
||||
|
||||
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
|
||||
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
|
||||
@@ -176,22 +168,21 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
|
||||
|
||||
# we dump the examples that are too small to fit in the block size for the
|
||||
# sake of simplicity. You can modify this by adding model-specific padding.
|
||||
if len(src_tokens) + len(src_tokens) + 3 < block_size:
|
||||
if len(src_sequence) + len(src_sequence) + 3 < block_size:
|
||||
return None
|
||||
|
||||
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
|
||||
if len(src_tokens) > SRC_MAX_LENGTH
|
||||
if len(tgt_tokens) > TGT_MAX_LENGTH:
|
||||
src_tokens = src_tokens[:SRC_MAX_LENGTH]
|
||||
tgt_tokens = tgt_tokens[:TGT_MAX_LENGTH]
|
||||
if len(src_sequence) > SRC_MAX_LENGTH
|
||||
if len(tgt_sequence) > TGT_MAX_LENGTH:
|
||||
src_sequence = src_sequence[:SRC_MAX_LENGTH]
|
||||
tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
|
||||
else:
|
||||
src_tokens = src_tokens[block_size - len(tgt_tokens) - 3]
|
||||
src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
|
||||
else:
|
||||
if len(tgt_tokens) > TGT_MAX_LENGTH:
|
||||
tgt_tokens = tgt_tokens[block_size - len(src_tokens) - 3]
|
||||
tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]
|
||||
|
||||
# I add the special tokens manually, but this should be done by the tokenizer. That's the next step.
|
||||
return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]
|
||||
return src_sequence, tgt_sequence
|
||||
|
||||
|
||||
|
||||
@@ -250,4 +241,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user