Fix GA loss bugs and add unit test (#35121)
* fix GA bugs and add unit test * narrow down model loss unit test diff gap * format code to make ruff happy * send num_items_in_batch argument to decoder * fix GA loss bug in BertLMHeadModel * use TinyStories-33M to narrow down diff gap * fotmat code * missing .config * avoid add extra args --------- Co-authored-by: kangsheng <kangsheng@meituan.com>
This commit is contained in:
@@ -1325,6 +1325,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
**loss_kwargs,
|
||||||
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
@@ -1375,11 +1376,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
|
|||||||
|
|
||||||
lm_loss = None
|
lm_loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
|
||||||
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
|
||||||
labels = labels[:, 1:].contiguous()
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (prediction_scores,) + outputs[2:]
|
output = (prediction_scores,) + outputs[2:]
|
||||||
|
|||||||
@@ -491,6 +491,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
|
|||||||
kwargs_decoder = {
|
kwargs_decoder = {
|
||||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||||
}
|
}
|
||||||
|
if "num_items_in_batch" in kwargs_encoder:
|
||||||
|
kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
|
||||||
|
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
if inputs is None:
|
if inputs is None:
|
||||||
|
|||||||
@@ -3649,9 +3649,6 @@ class Trainer:
|
|||||||
return loss_mb.reduce_mean().detach().to(self.args.device)
|
return loss_mb.reduce_mean().detach().to(self.args.device)
|
||||||
|
|
||||||
with self.compute_loss_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
if self.model_accepts_loss_kwargs:
|
|
||||||
loss = self.compute_loss(model, inputs)
|
|
||||||
else:
|
|
||||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
|
||||||
del inputs
|
del inputs
|
||||||
@@ -5132,10 +5129,6 @@ class Trainer:
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Keep default behavior the same
|
|
||||||
if not self.model_accepts_loss_kwargs:
|
|
||||||
return batch_samples, None
|
|
||||||
|
|
||||||
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
|
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
|
||||||
# For now we don't support object detection
|
# For now we don't support object detection
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -750,11 +750,102 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_gradient_accumulation_loss_alignment(self):
|
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
model_name = "distilgpt2"
|
model_name = "nickypro/tinyllama-110M"
|
||||||
|
dataset_name = "wikitext"
|
||||||
|
dataset_config = "wikitext-2-raw-v1"
|
||||||
|
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
|
||||||
|
dataset = dataset.train_test_split(test_size=0.2)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
def tokenize_function(examples):
|
||||||
|
return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
|
||||||
|
|
||||||
|
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
|
||||||
|
|
||||||
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||||
|
|
||||||
|
base_loss_callback = StoreLossCallback()
|
||||||
|
|
||||||
|
args_kwargs = {
|
||||||
|
"report_to": "none",
|
||||||
|
"logging_steps": 1,
|
||||||
|
"max_steps": 20,
|
||||||
|
"learning_rate": 3e-4,
|
||||||
|
"disable_tqdm": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
args = TrainingArguments(
|
||||||
|
"./generation",
|
||||||
|
**args_kwargs,
|
||||||
|
)
|
||||||
|
trainer = Trainer(
|
||||||
|
model,
|
||||||
|
args,
|
||||||
|
train_dataset=tokenized_dataset["train"],
|
||||||
|
callbacks=[base_loss_callback],
|
||||||
|
data_collator=data_collator,
|
||||||
|
)
|
||||||
|
assert trainer.model_accepts_loss_kwargs
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
grad_accum_loss_callback = StoreLossCallback()
|
||||||
|
args = TrainingArguments(
|
||||||
|
"./generation",
|
||||||
|
**args_kwargs,
|
||||||
|
gradient_accumulation_steps=2,
|
||||||
|
per_device_train_batch_size=4,
|
||||||
|
)
|
||||||
|
set_seed(42)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||||
|
trainer = Trainer(
|
||||||
|
model,
|
||||||
|
args,
|
||||||
|
train_dataset=tokenized_dataset["train"],
|
||||||
|
callbacks=[grad_accum_loss_callback],
|
||||||
|
data_collator=data_collator,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
set_seed(42)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||||
|
broken_loss_callback = StoreLossCallback()
|
||||||
|
trainer = Trainer(
|
||||||
|
model,
|
||||||
|
args,
|
||||||
|
train_dataset=tokenized_dataset["train"],
|
||||||
|
callbacks=[broken_loss_callback],
|
||||||
|
data_collator=data_collator,
|
||||||
|
)
|
||||||
|
# disable model_accepts_loss_kwargs
|
||||||
|
trainer.model_accepts_loss_kwargs = False
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Calculate the difference between the base loss and the grad_accum loss
|
||||||
|
diff_truth = [
|
||||||
|
abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
|
||||||
|
]
|
||||||
|
diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
|
||||||
|
|
||||||
|
# all diff truth should be quite close
|
||||||
|
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
|
||||||
|
|
||||||
|
# max diff broken should be very off
|
||||||
|
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
|
||||||
|
set_seed(42)
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
model_name = "roneneldan/TinyStories-33M"
|
||||||
dataset_name = "wikitext"
|
dataset_name = "wikitext"
|
||||||
dataset_config = "wikitext-2-raw-v1"
|
dataset_config = "wikitext-2-raw-v1"
|
||||||
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
|
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
|
||||||
@@ -836,15 +927,16 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
# Calculate the difference between the base loss and the grad_accum loss
|
# Calculate the difference between the base loss and the grad_accum loss
|
||||||
diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
|
diff_truth = [
|
||||||
diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
|
abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
|
||||||
# These should be quite close
|
]
|
||||||
for diff in diff_truth:
|
diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
|
||||||
self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
|
|
||||||
|
|
||||||
# These should be very off
|
# all diff truth should be quite close
|
||||||
for diff in diff_broken:
|
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
|
||||||
self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
|
|
||||||
|
# max diff broken should be very off
|
||||||
|
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
|
||||||
|
|
||||||
def test_gradient_accumulation(self):
|
def test_gradient_accumulation(self):
|
||||||
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
|
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
|
||||||
|
|||||||
Reference in New Issue
Block a user