From c1bc709c3545fbafd7d7d9da01ba89d35aff6a79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 17 Oct 2019 10:41:53 +0200 Subject: [PATCH] correct the truncation and padding of dataset --- examples/run_seq2seq_finetuning.py | 48 +++++++----------------------- 1 file changed, 10 insertions(+), 38 deletions(-) diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py index 2e8d0aa250..32f1782cab 100644 --- a/examples/run_seq2seq_finetuning.py +++ b/examples/run_seq2seq_finetuning.py @@ -104,9 +104,11 @@ class TextDataset(Dataset): except IndexError: # skip ill-formed stories continue - story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary)) - story_seq, summary_seq = _fit_to_block_size(story, summary, block_size) + summary_seq = _fit_to_block_size(summary, block_size) + + story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) + story_seq = _fit_to_block_size(story, block_size) self.examples.append( tokenizer.add_special_token_sequence_pair(story_seq, summary_seq) @@ -170,45 +172,15 @@ def _add_missing_period(line): return line + "." -def _fit_to_block_size(src_sequence, tgt_sequence, block_size): +def _fit_to_block_size(sequence, block_size): """ Adapt the source and target sequences' lengths to the block size. - - If the concatenated sequence (source + target + 3 special tokens) would be - longer than the block size we use the 75% / 25% rule followed in [1]. For a - block size of 512 this means limiting the source sequence's length to 384 - and the target sequence's length to 128. - - Attributes: - src_sequence (list): a list of ids that maps to the tokens of the - source sequence. - tgt_sequence (list): a list of ids that maps to the tokens of the - target sequence. - block_size (int): the model's block size. - - [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural - Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019). + If the sequence is shorter than the block size we pad it with -1 ids + which correspond to padding tokens. """ - SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token - TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1 # EOS token - - # We dump the examples that are too small to fit in the block size for the - # sake of simplicity. You can modify this by adding model-specific padding. - if len(src_sequence) + len(tgt_sequence) + 3 < block_size: - return None - - if len(src_sequence) > SRC_MAX_LENGTH: - if len(tgt_sequence) > TGT_MAX_LENGTH: - src_sequence = src_sequence[:SRC_MAX_LENGTH] - tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH] - else: - remain_size = block_size - len(tgt_sequence) - 3 - src_sequence = src_sequence[:remain_size] + if len(sequence) > block_size: + return sequence[:block_size] else: - if len(tgt_sequence) > TGT_MAX_LENGTH: - remain_size = block_size - len(src_sequence) - 3 - tgt_sequence = tgt_sequence[:remain_size] - - return src_sequence, tgt_sequence + return sequence.extend([-1] * [block_size - len(sequence)]) def load_and_cache_examples(args, tokenizer):