From 067c4a310dd36d0472d4a587145e94d20bf64964 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 14 Nov 2023 14:54:44 -0500 Subject: [PATCH] Have seq2seq just use gather (#27025) * Have seq2seq just use gather * Change * Reset after * Make slow * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Clean * Simplify and just use gather * Update tests/trainer/test_trainer_seq2seq.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * gather always for seq2seq --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/trainer.py | 12 ++++-- src/transformers/trainer_seq2seq.py | 4 +- tests/trainer/test_trainer_seq2seq.py | 61 ++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 40159d8163..51c60ce690 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3208,13 +3208,13 @@ class Trainer: # Update containers on host if loss is not None: - losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size))) + losses = self.gather_function((loss.repeat(batch_size))) losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) - inputs_decode = self.accelerator.gather_for_metrics((inputs_decode)) + inputs_decode = self.gather_function((inputs_decode)) inputs_host = ( inputs_decode if inputs_host is None @@ -3224,11 +3224,11 @@ class Trainer: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) - logits = self.accelerator.gather_for_metrics((logits)) + logits = self.gather_function((logits)) preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: - labels = self.accelerator.gather_for_metrics((labels)) + labels = self.gather_function((labels)) labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) @@ -3261,6 +3261,8 @@ class Trainer: # Set back to None to begin a new accumulation losses_host, preds_host, inputs_host, labels_host = None, None, None, None + # After all calls to `.gather_function`, reset to `gather_for_metrics`: + self.gather_function = self.accelerator.gather_for_metrics if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") @@ -3930,6 +3932,8 @@ class Trainer: deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, ) + # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag + self.gather_function = self.accelerator.gather_for_metrics # deepspeed and accelerate flags covering both trainer args and accelerate launcher self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 13d407bec4..9f6bf13245 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -160,8 +160,9 @@ class Seq2SeqTrainer(Trainer): gen_kwargs["max_length"] = self.args.generation_max_length if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None: gen_kwargs["num_beams"] = self.args.generation_num_beams + # We don't want to drop samples in general + self.gather_function = self.accelerator.gather self._gen_kwargs = gen_kwargs - return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) def predict( @@ -223,6 +224,7 @@ class Seq2SeqTrainer(Trainer): gen_kwargs["max_length"] = self.args.generation_max_length if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None: gen_kwargs["num_beams"] = self.args.generation_num_beams + self.gather_function = self.accelerator.gather self._gen_kwargs = gen_kwargs return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) diff --git a/tests/trainer/test_trainer_seq2seq.py b/tests/trainer/test_trainer_seq2seq.py index 918c221558..3f875e6d36 100644 --- a/tests/trainer/test_trainer_seq2seq.py +++ b/tests/trainer/test_trainer_seq2seq.py @@ -12,8 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments +from transformers import ( + AutoModelForSeq2SeqLM, + BertTokenizer, + DataCollatorForSeq2Seq, + EncoderDecoderModel, + GenerationConfig, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + T5Tokenizer, +) from transformers.testing_utils import TestCasePlus, require_torch, slow from transformers.utils import is_datasets_available @@ -124,3 +132,52 @@ class Seq2seqTrainerTester(TestCasePlus): # start training trainer.train() + + @slow + @require_torch + def test_return_sequences(self): + # Tests that the number of generated sequences is correct when num_return_sequences > 1 + # and essentially ensuring that `accelerator.gather()` is used instead of `gather_for_metrics` + INPUT_COLUMN = "question" + TARGET_COLUMN = "answer" + MAX_INPUT_LENGTH = 256 + MAX_TARGET_LENGTH = 256 + + dataset = datasets.load_dataset("gsm8k", "main", split="train[:38]") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-small") + tokenizer = T5Tokenizer.from_pretrained("t5-small") + data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest") + gen_config = GenerationConfig.from_pretrained( + "t5-small", max_length=None, min_length=None, max_new_tokens=256, min_new_tokens=1, num_beams=5 + ) + + training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True) + + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=lambda x: {"samples": x[0].shape[0]}, + ) + + def prepare_data(examples): + # Remove pairs where at least one record is none + inputs = examples[INPUT_COLUMN] + targets = examples[TARGET_COLUMN] + + model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True) + labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, truncation=True) + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + prepared_dataset = dataset.map(prepare_data, batched=True, remove_columns=[INPUT_COLUMN, TARGET_COLUMN]) + dataset_len = len(prepared_dataset) # 38 + + for num_return_sequences in range(3, 0, -1): + gen_config.num_return_sequences = num_return_sequences + metrics = trainer.evaluate(eval_dataset=prepared_dataset, generation_config=gen_config) + assert ( + metrics["eval_samples"] == dataset_len * num_return_sequences + ), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"