mBART support for run_summarization.py (#15125)
* Update run_summarization.py * Fixed languages and added missing code * fixed obj, docs, removed source_lang and target_lang * make style, run_summarization.py reformatted
This commit is contained in:
@@ -37,6 +37,10 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
|
MBart50Tokenizer,
|
||||||
|
MBart50TokenizerFast,
|
||||||
|
MBartTokenizer,
|
||||||
|
MBartTokenizerFast,
|
||||||
Seq2SeqTrainer,
|
Seq2SeqTrainer,
|
||||||
Seq2SeqTrainingArguments,
|
Seq2SeqTrainingArguments,
|
||||||
set_seed,
|
set_seed,
|
||||||
@@ -64,6 +68,9 @@ except (LookupError, OSError):
|
|||||||
with FileLock(".lock") as lock:
|
with FileLock(".lock") as lock:
|
||||||
nltk.download("punkt", quiet=True)
|
nltk.download("punkt", quiet=True)
|
||||||
|
|
||||||
|
# A list of all multilingual tokenizer which require lang attribute.
|
||||||
|
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
@@ -114,6 +121,8 @@ class DataTrainingArguments:
|
|||||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
lang: str = field(default=None, metadata={"help": "Language id for summarization."})
|
||||||
|
|
||||||
dataset_name: Optional[str] = field(
|
dataset_name: Optional[str] = field(
|
||||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||||
)
|
)
|
||||||
@@ -217,12 +226,24 @@ class DataTrainingArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
source_prefix: Optional[str] = field(
|
source_prefix: Optional[str] = field(
|
||||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||||
|
)
|
||||||
|
|
||||||
|
forced_bos_token: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The token to force as the first generated token after the decoder_start_token_id."
|
||||||
|
"Useful for multilingual models like mBART where the first generated token"
|
||||||
|
"needs to be the target language token (Usually it is the target language token)"
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||||
|
elif self.lang is None:
|
||||||
|
raise ValueError("Need to specify the language.")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.train_file is not None:
|
if self.train_file is not None:
|
||||||
extension = self.train_file.split(".")[-1]
|
extension = self.train_file.split(".")[-1]
|
||||||
@@ -370,6 +391,12 @@ def main():
|
|||||||
|
|
||||||
model.resize_token_embeddings(len(tokenizer))
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||||
|
if isinstance(tokenizer, MBartTokenizer):
|
||||||
|
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
|
||||||
|
else:
|
||||||
|
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
|
||||||
|
|
||||||
if model.config.decoder_start_token_id is None:
|
if model.config.decoder_start_token_id is None:
|
||||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||||
|
|
||||||
@@ -406,6 +433,21 @@ def main():
|
|||||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
|
||||||
|
assert (
|
||||||
|
data_args.lang is not None
|
||||||
|
), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
|
||||||
|
|
||||||
|
tokenizer.src_lang = data_args.lang
|
||||||
|
tokenizer.tgt_lang = data_args.lang
|
||||||
|
|
||||||
|
# For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
|
||||||
|
# as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
|
||||||
|
forced_bos_token_id = (
|
||||||
|
tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
|
||||||
|
)
|
||||||
|
model.config.forced_bos_token_id = forced_bos_token_id
|
||||||
|
|
||||||
# Get the column names for input/target.
|
# Get the column names for input/target.
|
||||||
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
|
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
|
||||||
if data_args.text_column is None:
|
if data_args.text_column is None:
|
||||||
@@ -436,14 +478,16 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
def preprocess_function(examples):
|
def preprocess_function(examples):
|
||||||
|
|
||||||
# remove pairs where at least one record is None
|
# remove pairs where at least one record is None
|
||||||
|
|
||||||
inputs, targets = [], []
|
inputs, targets = [], []
|
||||||
for i in range(len(examples[text_column])):
|
for i in range(len(examples[text_column])):
|
||||||
if examples[text_column][i] is not None and examples[summary_column][i] is not None:
|
if examples[text_column][i] is not None and examples[summary_column][i] is not None:
|
||||||
inputs.append(examples[text_column][i])
|
inputs.append(examples[text_column][i])
|
||||||
targets.append(examples[summary_column][i])
|
targets.append(examples[summary_column][i])
|
||||||
|
|
||||||
|
inputs = examples[text_column]
|
||||||
|
targets = examples[summary_column]
|
||||||
inputs = [prefix + inp for inp in inputs]
|
inputs = [prefix + inp for inp in inputs]
|
||||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||||
|
|
||||||
@@ -637,6 +681,9 @@ def main():
|
|||||||
else:
|
else:
|
||||||
kwargs["dataset"] = data_args.dataset_name
|
kwargs["dataset"] = data_args.dataset_name
|
||||||
|
|
||||||
|
if data_args.lang is not None:
|
||||||
|
kwargs["language"] = data_args.lang
|
||||||
|
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
trainer.push_to_hub(**kwargs)
|
trainer.push_to_hub(**kwargs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user