mBART support for run_summarization.py (#15125)

* Update run_summarization.py

* Fixed languages and added missing code

* fixed obj, docs, removed source_lang and target_lang

* make style, run_summarization.py reformatted
This commit is contained in:
Edoardo Federici
2022-01-12 22:39:33 +01:00
committed by GitHub
parent 97f3beed36
commit 9a94bb8e21

View File

@@ -37,6 +37,10 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
HfArgumentParser, HfArgumentParser,
MBart50Tokenizer,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer, Seq2SeqTrainer,
Seq2SeqTrainingArguments, Seq2SeqTrainingArguments,
set_seed, set_seed,
@@ -64,6 +68,9 @@ except (LookupError, OSError):
with FileLock(".lock") as lock: with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True) nltk.download("punkt", quiet=True)
# A list of all multilingual tokenizer which require lang attribute.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
@dataclass @dataclass
class ModelArguments: class ModelArguments:
@@ -114,6 +121,8 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input our model for training and eval. Arguments pertaining to what data we are going to input our model for training and eval.
""" """
lang: str = field(default=None, metadata={"help": "Language id for summarization."})
dataset_name: Optional[str] = field( dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
) )
@@ -217,12 +226,24 @@ class DataTrainingArguments:
}, },
) )
source_prefix: Optional[str] = field( source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the decoder_start_token_id."
"Useful for multilingual models like mBART where the first generated token"
"needs to be the target language token (Usually it is the target language token)"
},
) )
def __post_init__(self): def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None: if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.") raise ValueError("Need either a dataset name or a training/validation file.")
elif self.lang is None:
raise ValueError("Need to specify the language.")
else: else:
if self.train_file is not None: if self.train_file is not None:
extension = self.train_file.split(".")[-1] extension = self.train_file.split(".")[-1]
@@ -370,6 +391,12 @@ def main():
model.resize_token_embeddings(len(tokenizer)) model.resize_token_embeddings(len(tokenizer))
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
if model.config.decoder_start_token_id is None: if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -406,6 +433,21 @@ def main():
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return return
if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
assert (
data_args.lang is not None
), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
tokenizer.src_lang = data_args.lang
tokenizer.tgt_lang = data_args.lang
# For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
# as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
forced_bos_token_id = (
tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
)
model.config.forced_bos_token_id = forced_bos_token_id
# Get the column names for input/target. # Get the column names for input/target.
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
if data_args.text_column is None: if data_args.text_column is None:
@@ -436,14 +478,16 @@ def main():
) )
def preprocess_function(examples): def preprocess_function(examples):
# remove pairs where at least one record is None # remove pairs where at least one record is None
inputs, targets = [], [] inputs, targets = [], []
for i in range(len(examples[text_column])): for i in range(len(examples[text_column])):
if examples[text_column][i] is not None and examples[summary_column][i] is not None: if examples[text_column][i] is not None and examples[summary_column][i] is not None:
inputs.append(examples[text_column][i]) inputs.append(examples[text_column][i])
targets.append(examples[summary_column][i]) targets.append(examples[summary_column][i])
inputs = examples[text_column]
targets = examples[summary_column]
inputs = [prefix + inp for inp in inputs] inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
@@ -637,6 +681,9 @@ def main():
else: else:
kwargs["dataset"] = data_args.dataset_name kwargs["dataset"] = data_args.dataset_name
if data_args.lang is not None:
kwargs["language"] = data_args.lang
if training_args.push_to_hub: if training_args.push_to_hub:
trainer.push_to_hub(**kwargs) trainer.push_to_hub(**kwargs)
else: else: