From 9a94bb8e218033cffa1ef380010b528410ba3ca7 Mon Sep 17 00:00:00 2001
From: Edoardo Federici <49756048+banda-larga@users.noreply.github.com>
Date: Wed, 12 Jan 2022 22:39:33 +0100
Subject: [PATCH] mBART support for run_summarization.py (#15125)

* Update run_summarization.py

* Fixed languages and added missing code

* fixed obj, docs, removed source_lang and target_lang

* make style, run_summarization.py reformatted
---
Review notes (this section is not part of the commit message):

* `--lang` is no longer required for every run. The unconditional
  `elif self.lang is None: raise ValueError(...)` in `__post_init__` is
  removed; the language is validated in `main()` only when the loaded
  tokenizer is one of MULTILINGUAL_TOKENIZERS, and before its first use.
* The `assert` on `data_args.lang` is replaced by that `raise ValueError`
  (same message), because asserts are stripped under `python -O`.
* The `preprocess_function` hunk is dropped. It re-assigned `inputs` and
  `targets` from the raw columns after the filtering loop, which discarded
  the filtered lists and let `None` records reach `prefix + inp`.
* The `forced_bos_token` help text gets the missing spaces between its
  implicitly concatenated string pieces.
* `lang` is annotated `Optional[str]` to match its `None` default.
* NOTE(review): the post-image blob id on the `index` line below is the one
  from the original patch and was not recomputed for the amended content.

 .../summarization/run_summarization.py        | 46 ++++++++++++++++++-
 1 file changed, 45 insertions(+), 1 deletion(-)

diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index c2d0ff8795..4e717d8815 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -37,6 +37,10 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
+    MBart50Tokenizer,
+    MBart50TokenizerFast,
+    MBartTokenizer,
+    MBartTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
@@ -64,6 +68,9 @@ except (LookupError, OSError):
     with FileLock(".lock") as lock:
         nltk.download("punkt", quiet=True)
 
+# A list of all multilingual tokenizers which require the lang attribute.
+MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
+
 
 @dataclass
 class ModelArguments:
@@ -114,6 +121,8 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
+    lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
+
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
@@ -217,7 +226,16 @@ class DataTrainingArguments:
         },
     )
     source_prefix: Optional[str] = field(
-        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+        default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+    )
+
+    forced_bos_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The token to force as the first generated token after the decoder_start_token_id. "
+            "Useful for multilingual models like mBART where the first generated token "
+            "needs to be the target language token (Usually it is the target language token)"
+        },
     )
 
     def __post_init__(self):
@@ -370,6 +388,18 @@ def main():
 
     model.resize_token_embeddings(len(tokenizer))
 
+    # Multilingual tokenizers need a language code: fail early, before `--lang` is first used below.
+    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)) and data_args.lang is None:
+        raise ValueError(
+            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
+        )
+
+    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
+        if isinstance(tokenizer, MBartTokenizer):
+            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
+        else:
+            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
+
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
@@ -406,6 +436,17 @@ def main():
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
         return
 
+    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
+        tokenizer.src_lang = data_args.lang
+        tokenizer.tgt_lang = data_args.lang
+
+        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
+        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
+        forced_bos_token_id = (
+            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
+        )
+        model.config.forced_bos_token_id = forced_bos_token_id
+
     # Get the column names for input/target.
     dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
     if data_args.text_column is None:
@@ -637,6 +678,9 @@ def main():
         else:
             kwargs["dataset"] = data_args.dataset_name
 
+    if data_args.lang is not None:
+        kwargs["language"] = data_args.lang
+
     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
     else: