Add more rigerous non-slow grad accum tests (#35668)

* Add more rigerous non-slow grad accum tests

* Further nits

* Re-add space

* Readbility

* Use tinystories instead

* Revert transformer diff

* tweak threshs
This commit is contained in:
Zach Mueller
2025-02-12 10:26:21 -05:00
committed by GitHub
parent f869d486d3
commit 1fae54c721

View File

@@ -793,35 +793,34 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train() trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True) self.check_trained_model(trainer.model, alternate_seed=True)
@slow
def test_gradient_accumulation_loss_alignment_with_model_loss(self): def test_gradient_accumulation_loss_alignment_with_model_loss(self):
set_seed(42) set_seed(42)
import datasets import datasets
model_name = "nickypro/tinyllama-110M" model_name = "nickypro/tinyllama-15M"
dataset_name = "wikitext" dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1" dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]") dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:40]")
dataset = dataset.train_test_split(test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples): def tokenize_function(examples):
return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True) return tokenizer(examples["text"], max_length=16, padding="max_length", truncation=True)
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names) tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = AutoModelForCausalLM.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name)
state_dict = model.state_dict()
base_loss_callback = StoreLossCallback() base_loss_callback = StoreLossCallback()
args_kwargs = { args_kwargs = {
"report_to": "none", "report_to": "none",
"logging_steps": 1, "logging_steps": 1,
"max_steps": 20, "max_steps": 5,
"learning_rate": 3e-4, "learning_rate": 3e-4,
"disable_tqdm": True, "disable_tqdm": True,
} }
@@ -834,7 +833,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer( trainer = Trainer(
model, model,
args, args,
train_dataset=tokenized_dataset["train"], train_dataset=tokenized_dataset,
callbacks=[base_loss_callback], callbacks=[base_loss_callback],
data_collator=data_collator, data_collator=data_collator,
) )
@@ -854,19 +853,19 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer( trainer = Trainer(
model, model,
args, args,
train_dataset=tokenized_dataset["train"], train_dataset=tokenized_dataset,
callbacks=[grad_accum_loss_callback], callbacks=[grad_accum_loss_callback],
data_collator=data_collator, data_collator=data_collator,
) )
trainer.train() trainer.train()
set_seed(42) set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name) model.load_state_dict(state_dict)
broken_loss_callback = StoreLossCallback() broken_loss_callback = StoreLossCallback()
trainer = Trainer( trainer = Trainer(
model, model,
args, args,
train_dataset=tokenized_dataset["train"], train_dataset=tokenized_dataset,
callbacks=[broken_loss_callback], callbacks=[broken_loss_callback],
data_collator=data_collator, data_collator=data_collator,
) )
@@ -886,16 +885,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01") self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
# max diff broken should be very off # max diff broken should be very off
self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2") self.assertGreater(max(diff_broken), 1.5, f"Difference {max(diff_broken)} is not greater than 2")
loss_base = sum(base_loss_callback.losses) loss_base = sum(base_loss_callback.losses)
loss_broken = sum(broken_loss_callback.losses) loss_broken = sum(broken_loss_callback.losses)
# mean/sum loss should not vary too much. # mean/sum loss should not vary too much.
relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken) relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken)
self.assertLess(relative_diff, 0.1, f"Relative difference {relative_diff} is not within 0.1") self.assertLess(relative_diff, 0.2, f"Relative difference {relative_diff} is not within 0.2")
@slow
def test_gradient_accumulation_loss_alignment_with_loss_func(self): def test_gradient_accumulation_loss_alignment_with_loss_func(self):
set_seed(42) set_seed(42)
import datasets import datasets
@@ -903,14 +901,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
model_name = "roneneldan/TinyStories-33M" model_name = "roneneldan/TinyStories-33M"
dataset_name = "wikitext" dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1" dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]") dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:40]")
dataset = dataset.train_test_split(test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_function(examples): tokenizer.pad_token = tokenizer.eos_token
return tokenizer(examples["text"])
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names) def tokenize_function(examples):
return tokenizer(examples["text"], max_length=16, padding="max_length", truncation=True)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -929,7 +928,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
args_kwargs = { args_kwargs = {
"report_to": "none", "report_to": "none",
"logging_steps": 1, "logging_steps": 1,
"max_steps": 20, "max_steps": 5,
"learning_rate": 3e-4, "learning_rate": 3e-4,
"disable_tqdm": True, "disable_tqdm": True,
} }
@@ -942,7 +941,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer( trainer = Trainer(
model, model,
args, args,
train_dataset=tokenized_dataset["train"], train_dataset=tokenized_dataset,
callbacks=[base_loss_callback], callbacks=[base_loss_callback],
compute_loss_func=loss_fn, compute_loss_func=loss_fn,
data_collator=data_collator, data_collator=data_collator,
@@ -962,7 +961,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer( trainer = Trainer(
model, model,
args, args,
train_dataset=tokenized_dataset["train"], train_dataset=tokenized_dataset,
callbacks=[grad_accum_loss_callback], callbacks=[grad_accum_loss_callback],
compute_loss_func=loss_fn, compute_loss_func=loss_fn,
data_collator=data_collator, data_collator=data_collator,
@@ -976,7 +975,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer( trainer = Trainer(
model, model,
args, args,
train_dataset=tokenized_dataset["train"], train_dataset=tokenized_dataset,
callbacks=[broken_loss_callback], callbacks=[broken_loss_callback],
compute_loss_func=loss_fn, compute_loss_func=loss_fn,
data_collator=data_collator, data_collator=data_collator,