Fix memory regression in Seq2Seq example (#9713)

* Fix memory regression in Seq2Seq example

* Fix test and properly deal with -100

* Easier condition with device safety

* Patch for MBartTokenzierFast
This commit is contained in:
Sylvain Gugger
2021-01-21 12:05:46 -05:00
committed by GitHub
parent a7dabfb3d1
commit 5f80c15ef5
5 changed files with 43 additions and 16 deletions

View File

@@ -26,6 +26,7 @@ from transformers import (
AutoTokenizer,
HfArgumentParser,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
@@ -220,11 +221,14 @@ def main():
data_args.eval_beams = model.config.num_beams
# set decoder_start_token_id for MBart
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
assert (
data_args.tgt_lang is not None and data_args.src_lang is not None
), "mBart requires --tgt_lang and --src_lang"
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)
if model_args.freeze_embeds:
freeze_embeds(model)
@@ -284,7 +288,9 @@ def main():
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
data_collator=Seq2SeqDataCollator(
tokenizer, data_args, model.config.decoder_start_token_id, training_args.tpu_num_cores
),
compute_metrics=compute_metrics_fn,
tokenizer=tokenizer,
)

View File

@@ -33,8 +33,9 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.file_utils import cached_property
from transformers.models.bart.modeling_bart import shift_tokens_right
try:
@@ -274,9 +275,10 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
self.decoder_start_token_id = decoder_start_token_id
assert (
self.pad_token_id is not None
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
@@ -304,9 +306,15 @@ class Seq2SeqDataCollator:
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
if isinstance(self.tokenizer, T5Tokenizer):
decoder_input_ids = self._shift_right_t5(labels)
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"labels": labels,
}
return batch