[examples/seq2seq] support label smoothing (#9844)

* add prepare_decoder_input_ids_from_labels in s2s models

* support label smoothing and encoder/embedding freezing

* fix freezing

* use pad_token_id from config

* remove embed freezing and add warning

* prepare decoder_input_ids inside DataCollatorForSeq2Seq
This commit is contained in:
Suraj Patil
2021-02-05 23:21:57 +05:30
committed by GitHub
parent b9720dd6f2
commit 1cd16512dc
10 changed files with 46 additions and 1 deletions

View File

@@ -384,6 +384,12 @@ def main():
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
# Warn when label smoothing is requested but the model does not expose
# `prepare_decoder_input_ids_from_labels`; as the message states, the loss is then
# calculated twice and more memory is used.
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
    # `Logger.warn` is a deprecated alias of `Logger.warning`.
    logger.warning(
        # Trailing space keeps "for" separated from the model name once the
        # adjacent string literals are concatenated.
        "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
        f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
    )
def preprocess_function(examples):
if data_args.task.startswith("translation"):
inputs = [ex[source_lang] for ex in examples["translation"]]
@@ -440,6 +446,7 @@ def main():
else:
# Build the seq2seq collator used on this branch (the matching `if` is outside the visible hunk).
data_collator = DataCollatorForSeq2Seq(
tokenizer,
# NOTE(review): presumably lets the collator prepare `decoder_input_ids` from the labels,
# as the commit description says -- confirm against DataCollatorForSeq2Seq.
model=model,
label_pad_token_id=label_pad_token_id,
# Request padding to a multiple of 8 only when fp16 training is enabled; otherwise pass None.
pad_to_multiple_of=8 if training_args.fp16 else None,
)