From c161dd56df34b036d95f0ff772782f510a4e3235 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 9 Apr 2021 23:58:42 +0530 Subject: [PATCH] [examples/translation] support mBART-50 and M2M100 fine-tuning (#11170) * keep a list of multilingual tokenizers * add forced_bos_token argument --- examples/seq2seq/run_translation.py | 37 ++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/examples/seq2seq/run_translation.py b/examples/seq2seq/run_translation.py index dab84d5915..a41da4e0ab 100755 --- a/examples/seq2seq/run_translation.py +++ b/examples/seq2seq/run_translation.py @@ -34,6 +34,9 @@ from transformers import ( AutoTokenizer, DataCollatorForSeq2Seq, HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + MBart50TokenizerFast, MBartTokenizer, MBartTokenizerFast, Seq2SeqTrainer, @@ -50,6 +53,9 @@ check_min_version("4.6.0.dev0") logger = logging.getLogger(__name__) +# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. +MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] + @dataclass class ModelArguments: @@ -191,6 +197,14 @@ class DataTrainingArguments: source_prefix: Optional[str] = field( default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} ) + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`." + "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token " + "needs to be the target language token.(Usually it is the target language token)" + }, + ) def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: @@ -325,9 +339,6 @@ def main(): # Set decoder_start_token_id if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): - assert ( - data_args.target_lang is not None and data_args.source_lang is not None - ), "mBart requires --target_lang and --source_lang" if isinstance(tokenizer, MBartTokenizer): model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] else: @@ -352,11 +363,21 @@ def main(): # For translation we set the codes of our source and target languages (only useful for mBART, the others will # ignore those attributes). - if isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): - if data_args.source_lang is not None: - tokenizer.src_lang = data_args.source_lang - if data_args.target_lang is not None: - tokenizer.tgt_lang = data_args.target_lang + if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): + assert data_args.target_lang is not None and data_args.source_lang is not None, ( + f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and " + "--target_lang arguments." + ) + + tokenizer.src_lang = data_args.source_lang + tokenizer.tgt_lang = data_args.target_lang + + # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token + # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. + forced_bos_token_id = ( + tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None + ) + model.config.foced_bos_token_id = forced_bos_token_id # Get the language codes for input/target. source_lang = data_args.source_lang.split("_")[0]