Fix memory regression in Seq2Seq example (#9713)

* Fix memory regression in Seq2Seq example

* Fix test and properly deal with -100

* Easier condition with device safety

* Patch for MBartTokenzierFast
This commit is contained in:
Sylvain Gugger
2021-01-21 12:05:46 -05:00
committed by GitHub
parent a7dabfb3d1
commit 5f80c15ef5
5 changed files with 43 additions and 16 deletions

View File

@@ -33,8 +33,9 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.file_utils import cached_property
from transformers.models.bart.modeling_bart import shift_tokens_right
try:
@@ -274,9 +275,10 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
self.decoder_start_token_id = decoder_start_token_id
assert (
self.pad_token_id is not None
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
@@ -304,9 +306,15 @@ class Seq2SeqDataCollator:
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
if isinstance(self.tokenizer, T5Tokenizer):
decoder_input_ids = self._shift_right_t5(labels)
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"labels": labels,
}
return batch