Fix memory regression in Seq2Seq example (#9713)
* Fix memory regression in Seq2Seq example * Fix test and properly deal with -100 * Easier condition with device safety * Patch for MBartTokenzierFast
This commit is contained in:
@@ -26,6 +26,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
set_seed,
|
||||
@@ -220,11 +221,14 @@ def main():
|
||||
data_args.eval_beams = model.config.num_beams
|
||||
|
||||
# set decoder_start_token_id for MBart
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
assert (
|
||||
data_args.tgt_lang is not None and data_args.src_lang is not None
|
||||
), "mBart requires --tgt_lang and --src_lang"
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)
|
||||
|
||||
if model_args.freeze_embeds:
|
||||
freeze_embeds(model)
|
||||
@@ -284,7 +288,9 @@ def main():
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||
data_collator=Seq2SeqDataCollator(
|
||||
tokenizer, data_args, model.config.decoder_start_token_id, training_args.tpu_num_cores
|
||||
),
|
||||
compute_metrics=compute_metrics_fn,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
@@ -33,8 +33,9 @@ from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
|
||||
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
|
||||
|
||||
try:
|
||||
@@ -274,9 +275,10 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
|
||||
|
||||
class Seq2SeqDataCollator:
|
||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||
def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
assert (
|
||||
self.pad_token_id is not None
|
||||
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
|
||||
@@ -304,9 +306,15 @@ class Seq2SeqDataCollator:
|
||||
labels = trim_batch(labels, self.pad_token_id)
|
||||
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
||||
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
decoder_input_ids = self._shift_right_t5(labels)
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
return batch
|
||||
|
||||
Reference in New Issue
Block a user