Fix GA loss bugs and add unit test (#35121)

* fix GA bugs and add unit test

* narrow down model loss unit test diff gap

* format code to make ruff happy

* send num_items_in_batch argument to decoder

* fix GA loss bug in BertLMHeadModel

* use TinyStories-33M to narrow down diff gap

* fotmat code

* missing .config

* avoid add extra args

---------

Co-authored-by: kangsheng <kangsheng@meituan.com>
This commit is contained in:
kang sheng
2024-12-09 16:57:41 +08:00
committed by GitHub
parent c8c8dffbe4
commit 1ccca8f48c
4 changed files with 107 additions and 23 deletions

View File

@@ -1325,6 +1325,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
**loss_kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1375,11 +1376,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
lm_loss = None lm_loss = None
if labels is not None: if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]

View File

@@ -491,6 +491,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
kwargs_decoder = { kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
} }
if "num_items_in_batch" in kwargs_encoder:
kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
if encoder_outputs is None: if encoder_outputs is None:
if inputs is None: if inputs is None:

View File

@@ -3649,10 +3649,7 @@ class Trainer:
return loss_mb.reduce_mean().detach().to(self.args.device) return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager(): with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs del inputs
if ( if (
@@ -5132,10 +5129,6 @@ class Trainer:
except StopIteration: except StopIteration:
break break
# Keep default behavior the same
if not self.model_accepts_loss_kwargs:
return batch_samples, None
if len(batch_samples) > 0 and "labels" in batch_samples[0]: if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection # For now we don't support object detection
try: try:

View File

@@ -750,11 +750,102 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.check_trained_model(trainer.model, alternate_seed=True) self.check_trained_model(trainer.model, alternate_seed=True)
@slow @slow
def test_gradient_accumulation_loss_alignment(self): def test_gradient_accumulation_loss_alignment_with_model_loss(self):
set_seed(42) set_seed(42)
import datasets import datasets
model_name = "distilgpt2" model_name = "nickypro/tinyllama-110M"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
dataset = dataset.train_test_split(test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = AutoModelForCausalLM.from_pretrained(model_name)
base_loss_callback = StoreLossCallback()
args_kwargs = {
"report_to": "none",
"logging_steps": 1,
"max_steps": 20,
"learning_rate": 3e-4,
"disable_tqdm": True,
}
args = TrainingArguments(
"./generation",
**args_kwargs,
)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[base_loss_callback],
data_collator=data_collator,
)
assert trainer.model_accepts_loss_kwargs
trainer.train()
grad_accum_loss_callback = StoreLossCallback()
args = TrainingArguments(
"./generation",
**args_kwargs,
gradient_accumulation_steps=2,
per_device_train_batch_size=4,
)
set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[grad_accum_loss_callback],
data_collator=data_collator,
)
trainer.train()
set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
broken_loss_callback = StoreLossCallback()
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[broken_loss_callback],
data_collator=data_collator,
)
# disable model_accepts_loss_kwargs
trainer.model_accepts_loss_kwargs = False
trainer.train()
# Calculate the difference between the base loss and the grad_accum loss
diff_truth = [
abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
]
diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
# all diff truth should be quite close
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
# max diff broken should be very off
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
@slow
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
set_seed(42)
import datasets
model_name = "roneneldan/TinyStories-33M"
dataset_name = "wikitext" dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1" dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]") dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
@@ -836,15 +927,16 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train() trainer.train()
# Calculate the difference between the base loss and the grad_accum loss # Calculate the difference between the base loss and the grad_accum loss
diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)] diff_truth = [
diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)] abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
# These should be quite close ]
for diff in diff_truth: diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
# These should be very off # all diff truth should be quite close
for diff in diff_broken: self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
# max diff broken should be very off
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
def test_gradient_accumulation(self): def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same training losses. # Training with half the batch size but accumulation steps as 2 should give the same training losses.