Fix GA loss bugs and add unit test (#35121)
* fix GA bugs and add unit test * narrow down model loss unit test diff gap * format code to make ruff happy * send num_items_in_batch argument to decoder * fix GA loss bug in BertLMHeadModel * use TinyStories-33M to narrow down diff gap * fotmat code * missing .config * avoid add extra args --------- Co-authored-by: kangsheng <kangsheng@meituan.com>
This commit is contained in:
@@ -750,11 +750,102 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||
|
||||
@slow
|
||||
def test_gradient_accumulation_loss_alignment(self):
|
||||
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
|
||||
set_seed(42)
|
||||
import datasets
|
||||
|
||||
model_name = "distilgpt2"
|
||||
model_name = "nickypro/tinyllama-110M"
|
||||
dataset_name = "wikitext"
|
||||
dataset_config = "wikitext-2-raw-v1"
|
||||
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
|
||||
dataset = dataset.train_test_split(test_size=0.2)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
|
||||
|
||||
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
|
||||
base_loss_callback = StoreLossCallback()
|
||||
|
||||
args_kwargs = {
|
||||
"report_to": "none",
|
||||
"logging_steps": 1,
|
||||
"max_steps": 20,
|
||||
"learning_rate": 3e-4,
|
||||
"disable_tqdm": True,
|
||||
}
|
||||
|
||||
args = TrainingArguments(
|
||||
"./generation",
|
||||
**args_kwargs,
|
||||
)
|
||||
trainer = Trainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=tokenized_dataset["train"],
|
||||
callbacks=[base_loss_callback],
|
||||
data_collator=data_collator,
|
||||
)
|
||||
assert trainer.model_accepts_loss_kwargs
|
||||
trainer.train()
|
||||
|
||||
grad_accum_loss_callback = StoreLossCallback()
|
||||
args = TrainingArguments(
|
||||
"./generation",
|
||||
**args_kwargs,
|
||||
gradient_accumulation_steps=2,
|
||||
per_device_train_batch_size=4,
|
||||
)
|
||||
set_seed(42)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
trainer = Trainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=tokenized_dataset["train"],
|
||||
callbacks=[grad_accum_loss_callback],
|
||||
data_collator=data_collator,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
set_seed(42)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
broken_loss_callback = StoreLossCallback()
|
||||
trainer = Trainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=tokenized_dataset["train"],
|
||||
callbacks=[broken_loss_callback],
|
||||
data_collator=data_collator,
|
||||
)
|
||||
# disable model_accepts_loss_kwargs
|
||||
trainer.model_accepts_loss_kwargs = False
|
||||
trainer.train()
|
||||
|
||||
# Calculate the difference between the base loss and the grad_accum loss
|
||||
diff_truth = [
|
||||
abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
|
||||
]
|
||||
diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
|
||||
|
||||
# all diff truth should be quite close
|
||||
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
|
||||
|
||||
# max diff broken should be very off
|
||||
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
|
||||
|
||||
@slow
|
||||
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
|
||||
set_seed(42)
|
||||
import datasets
|
||||
|
||||
model_name = "roneneldan/TinyStories-33M"
|
||||
dataset_name = "wikitext"
|
||||
dataset_config = "wikitext-2-raw-v1"
|
||||
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
|
||||
@@ -836,15 +927,16 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
|
||||
# Calculate the difference between the base loss and the grad_accum loss
|
||||
diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
|
||||
diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
|
||||
# These should be quite close
|
||||
for diff in diff_truth:
|
||||
self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
|
||||
diff_truth = [
|
||||
abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
|
||||
]
|
||||
diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
|
||||
|
||||
# These should be very off
|
||||
for diff in diff_broken:
|
||||
self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
|
||||
# all diff truth should be quite close
|
||||
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
|
||||
|
||||
# max diff broken should be very off
|
||||
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
|
||||
|
||||
def test_gradient_accumulation(self):
|
||||
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
|
||||
|
||||
Reference in New Issue
Block a user