mBART support for run_summarization.py (#15125)
* Update run_summarization.py * Fixed languages and added missing code * fixed obj, docs, removed source_lang and target_lang * make style, run_summarization.py reformatted
This commit is contained in:
@@ -37,6 +37,10 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
MBart50Tokenizer,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
set_seed,
|
||||
@@ -64,6 +68,9 @@ except (LookupError, OSError):
|
||||
with FileLock(".lock") as lock:
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
# A list of all multilingual tokenizer which require lang attribute.
|
||||
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
@@ -114,6 +121,8 @@ class DataTrainingArguments:
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
lang: str = field(default=None, metadata={"help": "Language id for summarization."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
@@ -217,12 +226,24 @@ class DataTrainingArguments:
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the decoder_start_token_id."
|
||||
"Useful for multilingual models like mBART where the first generated token"
|
||||
"needs to be the target language token (Usually it is the target language token)"
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
elif self.lang is None:
|
||||
raise ValueError("Need to specify the language.")
|
||||
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
@@ -370,6 +391,12 @@ def main():
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
@@ -406,6 +433,21 @@ def main():
|
||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||
return
|
||||
|
||||
if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
|
||||
assert (
|
||||
data_args.lang is not None
|
||||
), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
|
||||
|
||||
tokenizer.src_lang = data_args.lang
|
||||
tokenizer.tgt_lang = data_args.lang
|
||||
|
||||
# For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
|
||||
# as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
|
||||
forced_bos_token_id = (
|
||||
tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
|
||||
)
|
||||
model.config.forced_bos_token_id = forced_bos_token_id
|
||||
|
||||
# Get the column names for input/target.
|
||||
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
|
||||
if data_args.text_column is None:
|
||||
@@ -436,14 +478,16 @@ def main():
|
||||
)
|
||||
|
||||
def preprocess_function(examples):
|
||||
|
||||
# remove pairs where at least one record is None
|
||||
|
||||
inputs, targets = [], []
|
||||
for i in range(len(examples[text_column])):
|
||||
if examples[text_column][i] is not None and examples[summary_column][i] is not None:
|
||||
inputs.append(examples[text_column][i])
|
||||
targets.append(examples[summary_column][i])
|
||||
|
||||
inputs = examples[text_column]
|
||||
targets = examples[summary_column]
|
||||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
@@ -637,6 +681,9 @@ def main():
|
||||
else:
|
||||
kwargs["dataset"] = data_args.dataset_name
|
||||
|
||||
if data_args.lang is not None:
|
||||
kwargs["language"] = data_args.lang
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**kwargs)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user