From 9f8fa4e9730b8e658bcd5625610cc70f3a019818 Mon Sep 17 00:00:00 2001 From: Eliza Szczechla <3648991+elsanns@users.noreply.github.com> Date: Mon, 22 Mar 2021 20:05:39 +0100 Subject: [PATCH] Use DataCollatorForSeq2Seq in run_summarization in all cases (#10856) Co-authored-by: Eliza --- examples/seq2seq/run_summarization.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/examples/seq2seq/run_summarization.py b/examples/seq2seq/run_summarization.py index 43ae63b8ba..2dd1a0719d 100755 --- a/examples/seq2seq/run_summarization.py +++ b/examples/seq2seq/run_summarization.py @@ -38,7 +38,6 @@ from transformers import ( HfArgumentParser, Seq2SeqTrainer, Seq2SeqTrainingArguments, - default_data_collator, set_seed, ) from transformers.file_utils import is_offline_mode @@ -466,15 +465,12 @@ def main(): # Data collator label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id - if data_args.pad_to_max_length: - data_collator = default_data_collator - else: - data_collator = DataCollatorForSeq2Seq( - tokenizer, - model=model, - label_pad_token_id=label_pad_token_id, - pad_to_multiple_of=8 if training_args.fp16 else None, - ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) # Metric metric = load_metric("rouge")