correct the truncation and padding of dataset
This commit is contained in:
@@ -104,9 +104,11 @@ class TextDataset(Dataset):
|
|||||||
except IndexError: # skip ill-formed stories
|
except IndexError: # skip ill-formed stories
|
||||||
continue
|
continue
|
||||||
|
|
||||||
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
|
||||||
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||||
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
|
summary_seq = _fit_to_block_size(summary, block_size)
|
||||||
|
|
||||||
|
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||||
|
story_seq = _fit_to_block_size(story, block_size)
|
||||||
|
|
||||||
self.examples.append(
|
self.examples.append(
|
||||||
tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
||||||
@@ -170,45 +172,15 @@ def _add_missing_period(line):
|
|||||||
return line + "."
|
return line + "."
|
||||||
|
|
||||||
|
|
||||||
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
def _fit_to_block_size(sequence, block_size):
|
||||||
""" Adapt the source and target sequences' lengths to the block size.
|
""" Adapt the source and target sequences' lengths to the block size.
|
||||||
|
If the sequence is shorter than the block size we pad it with -1 ids
|
||||||
If the concatenated sequence (source + target + 3 special tokens) would be
|
which correspond to padding tokens.
|
||||||
longer than the block size we use the 75% / 25% rule followed in [1]. For a
|
|
||||||
block size of 512 this means limiting the source sequence's length to 384
|
|
||||||
and the target sequence's length to 128.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
src_sequence (list): a list of ids that maps to the tokens of the
|
|
||||||
source sequence.
|
|
||||||
tgt_sequence (list): a list of ids that maps to the tokens of the
|
|
||||||
target sequence.
|
|
||||||
block_size (int): the model's block size.
|
|
||||||
|
|
||||||
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
|
|
||||||
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
|
|
||||||
"""
|
"""
|
||||||
SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
|
if len(sequence) > block_size:
|
||||||
TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1 # EOS token
|
return sequence[:block_size]
|
||||||
|
|
||||||
# We dump the examples that are too small to fit in the block size for the
|
|
||||||
# sake of simplicity. You can modify this by adding model-specific padding.
|
|
||||||
if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if len(src_sequence) > SRC_MAX_LENGTH:
|
|
||||||
if len(tgt_sequence) > TGT_MAX_LENGTH:
|
|
||||||
src_sequence = src_sequence[:SRC_MAX_LENGTH]
|
|
||||||
tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
|
|
||||||
else:
|
else:
|
||||||
remain_size = block_size - len(tgt_sequence) - 3
|
return sequence.extend([-1] * [block_size - len(sequence)])
|
||||||
src_sequence = src_sequence[:remain_size]
|
|
||||||
else:
|
|
||||||
if len(tgt_sequence) > TGT_MAX_LENGTH:
|
|
||||||
remain_size = block_size - len(src_sequence) - 3
|
|
||||||
tgt_sequence = tgt_sequence[:remain_size]
|
|
||||||
|
|
||||||
return src_sequence, tgt_sequence
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer):
|
def load_and_cache_examples(args, tokenizer):
|
||||||
|
|||||||
Reference in New Issue
Block a user