[examples/translation] support mBART-50 and M2M100 fine-tuning (#11170)
* keep a list of multilingual tokenizers * add forced_bos_token argument
This commit is contained in:
@@ -34,6 +34,9 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
|
M2M100Tokenizer,
|
||||||
|
MBart50Tokenizer,
|
||||||
|
MBart50TokenizerFast,
|
||||||
MBartTokenizer,
|
MBartTokenizer,
|
||||||
MBartTokenizerFast,
|
MBartTokenizerFast,
|
||||||
Seq2SeqTrainer,
|
Seq2SeqTrainer,
|
||||||
@@ -50,6 +53,9 @@ check_min_version("4.6.0.dev0")
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
|
||||||
|
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
@@ -191,6 +197,14 @@ class DataTrainingArguments:
|
|||||||
source_prefix: Optional[str] = field(
|
source_prefix: Optional[str] = field(
|
||||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||||
)
|
)
|
||||||
|
forced_bos_token: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||||
|
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||||
|
"needs to be the target language token.(Usually it is the target language token)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||||
@@ -325,9 +339,6 @@ def main():
|
|||||||
|
|
||||||
# Set decoder_start_token_id
|
# Set decoder_start_token_id
|
||||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||||
assert (
|
|
||||||
data_args.target_lang is not None and data_args.source_lang is not None
|
|
||||||
), "mBart requires --target_lang and --source_lang"
|
|
||||||
if isinstance(tokenizer, MBartTokenizer):
|
if isinstance(tokenizer, MBartTokenizer):
|
||||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||||
else:
|
else:
|
||||||
@@ -352,12 +363,22 @@ def main():
|
|||||||
|
|
||||||
# For translation we set the codes of our source and target languages (only useful for mBART, the others will
|
# For translation we set the codes of our source and target languages (only useful for mBART, the others will
|
||||||
# ignore those attributes).
|
# ignore those attributes).
|
||||||
if isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
|
||||||
if data_args.source_lang is not None:
|
assert data_args.target_lang is not None and data_args.source_lang is not None, (
|
||||||
|
f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
|
||||||
|
"--target_lang arguments."
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer.src_lang = data_args.source_lang
|
tokenizer.src_lang = data_args.source_lang
|
||||||
if data_args.target_lang is not None:
|
|
||||||
tokenizer.tgt_lang = data_args.target_lang
|
tokenizer.tgt_lang = data_args.target_lang
|
||||||
|
|
||||||
|
# For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
|
||||||
|
# as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
|
||||||
|
forced_bos_token_id = (
|
||||||
|
tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
|
||||||
|
)
|
||||||
|
# generate() reads `config.forced_bos_token_id`; a misspelled attribute name would be silently ignored.
model.config.forced_bos_token_id = forced_bos_token_id
|
||||||
|
|
||||||
# Get the language codes for input/target.
|
# Get the language codes for input/target.
|
||||||
source_lang = data_args.source_lang.split("_")[0]
|
source_lang = data_args.source_lang.split("_")[0]
|
||||||
target_lang = data_args.target_lang.split("_")[0]
|
target_lang = data_args.target_lang.split("_")[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user