From 1ccca8f48c493a4804921da66f65023e3f9c8d9c Mon Sep 17 00:00:00 2001 From: kang sheng Date: Mon, 9 Dec 2024 16:57:41 +0800 Subject: [PATCH] Fix GA loss bugs and add unit test (#35121) * fix GA bugs and add unit test * narrow down model loss unit test diff gap * format code to make ruff happy * send num_items_in_batch argument to decoder * fix GA loss bug in BertLMHeadModel * use TinyStories-33M to narrow down diff gap * format code * missing .config * avoid adding extra args --------- Co-authored-by: kangsheng --- src/transformers/models/bert/modeling_bert.py | 7 +- .../modeling_speech_encoder_decoder.py | 2 + src/transformers/trainer.py | 9 +- tests/trainer/test_trainer.py | 112 ++++++++++++++++-- 4 files changed, 107 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 6b05fa6481..e311f93b6c 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1325,6 +1325,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **loss_kwargs, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -1375,11 +1376,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin): lm_loss = None if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs) if not return_dict: output = 
(prediction_scores,) + outputs[2:] diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 0d2b911beb..3bff8f6acd 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -491,6 +491,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin): kwargs_decoder = { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + if "num_items_in_batch" in kwargs_encoder: + kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None) if encoder_outputs is None: if inputs is None: diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index af908e48e4..f7d7948180 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3649,10 +3649,7 @@ class Trainer: return loss_mb.reduce_mean().detach().to(self.args.device) with self.compute_loss_context_manager(): - if self.model_accepts_loss_kwargs: - loss = self.compute_loss(model, inputs) - else: - loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) del inputs if ( @@ -5132,10 +5129,6 @@ class Trainer: except StopIteration: break - # Keep default behavior the same - if not self.model_accepts_loss_kwargs: - return batch_samples, None - if len(batch_samples) > 0 and "labels" in batch_samples[0]: # For now we don't support object detection try: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index f7b4a8637b..d33be27897 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -750,11 +750,102 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): 
self.check_trained_model(trainer.model, alternate_seed=True) @slow - def test_gradient_accumulation_loss_alignment(self): + def test_gradient_accumulation_loss_alignment_with_model_loss(self): set_seed(42) import datasets - model_name = "distilgpt2" + model_name = "nickypro/tinyllama-110M" + dataset_name = "wikitext" + dataset_config = "wikitext-2-raw-v1" + dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]") + dataset = dataset.train_test_split(test_size=0.2) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + tokenizer.pad_token = tokenizer.eos_token + + def tokenize_function(examples): + return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True) + + tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names) + + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + model = AutoModelForCausalLM.from_pretrained(model_name) + + base_loss_callback = StoreLossCallback() + + args_kwargs = { + "report_to": "none", + "logging_steps": 1, + "max_steps": 20, + "learning_rate": 3e-4, + "disable_tqdm": True, + } + + args = TrainingArguments( + "./generation", + **args_kwargs, + ) + trainer = Trainer( + model, + args, + train_dataset=tokenized_dataset["train"], + callbacks=[base_loss_callback], + data_collator=data_collator, + ) + assert trainer.model_accepts_loss_kwargs + trainer.train() + + grad_accum_loss_callback = StoreLossCallback() + args = TrainingArguments( + "./generation", + **args_kwargs, + gradient_accumulation_steps=2, + per_device_train_batch_size=4, + ) + set_seed(42) + model = AutoModelForCausalLM.from_pretrained(model_name) + trainer = Trainer( + model, + args, + train_dataset=tokenized_dataset["train"], + callbacks=[grad_accum_loss_callback], + data_collator=data_collator, + ) + trainer.train() + + set_seed(42) + model = AutoModelForCausalLM.from_pretrained(model_name) + broken_loss_callback = 
StoreLossCallback() + trainer = Trainer( + model, + args, + train_dataset=tokenized_dataset["train"], + callbacks=[broken_loss_callback], + data_collator=data_collator, + ) + # disable model_accepts_loss_kwargs + trainer.model_accepts_loss_kwargs = False + trainer.train() + + # Calculate the difference between the base loss and the grad_accum loss + diff_truth = [ + abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses) + ] + diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)] + + # all diff truth should be quite close + self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01") + + # max diff broken should be very off + self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3") + + @slow + def test_gradient_accumulation_loss_alignment_with_loss_func(self): + set_seed(42) + import datasets + + model_name = "roneneldan/TinyStories-33M" dataset_name = "wikitext" dataset_config = "wikitext-2-raw-v1" dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]") @@ -836,15 +927,16 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): trainer.train() # Calculate the difference between the base loss and the grad_accum loss - diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)] - diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)] - # These should be quite close - for diff in diff_truth: - self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1") + diff_truth = [ + abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses) + ] + diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)] - # These should be very off - for diff in diff_broken: - 
self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1") + # all diff truth should be quite close + self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01") + + # max diff broken should be very off + self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3") def test_gradient_accumulation(self): # Training with half the batch size but accumulation steps as 2 should give the same training losses.