Add more rigorous non-slow grad accum tests (#35668)
* Add more rigorous non-slow grad accum tests * Further nits * Re-add space * Readability * Use tinystories instead * Revert transformer diff * Tweak thresholds
This commit is contained in:
@@ -793,35 +793,34 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
|
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
model_name = "nickypro/tinyllama-110M"
|
model_name = "nickypro/tinyllama-15M"
|
||||||
dataset_name = "wikitext"
|
dataset_name = "wikitext"
|
||||||
dataset_config = "wikitext-2-raw-v1"
|
dataset_config = "wikitext-2-raw-v1"
|
||||||
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
|
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:40]")
|
||||||
dataset = dataset.train_test_split(test_size=0.2)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
def tokenize_function(examples):
|
def tokenize_function(examples):
|
||||||
return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
|
return tokenizer(examples["text"], max_length=16, padding="max_length", truncation=True)
|
||||||
|
|
||||||
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
|
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
|
||||||
|
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
base_loss_callback = StoreLossCallback()
|
base_loss_callback = StoreLossCallback()
|
||||||
|
|
||||||
args_kwargs = {
|
args_kwargs = {
|
||||||
"report_to": "none",
|
"report_to": "none",
|
||||||
"logging_steps": 1,
|
"logging_steps": 1,
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"learning_rate": 3e-4,
|
"learning_rate": 3e-4,
|
||||||
"disable_tqdm": True,
|
"disable_tqdm": True,
|
||||||
}
|
}
|
||||||
@@ -834,7 +833,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
train_dataset=tokenized_dataset["train"],
|
train_dataset=tokenized_dataset,
|
||||||
callbacks=[base_loss_callback],
|
callbacks=[base_loss_callback],
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
)
|
)
|
||||||
@@ -854,19 +853,19 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
train_dataset=tokenized_dataset["train"],
|
train_dataset=tokenized_dataset,
|
||||||
callbacks=[grad_accum_loss_callback],
|
callbacks=[grad_accum_loss_callback],
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
model.load_state_dict(state_dict)
|
||||||
broken_loss_callback = StoreLossCallback()
|
broken_loss_callback = StoreLossCallback()
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
train_dataset=tokenized_dataset["train"],
|
train_dataset=tokenized_dataset,
|
||||||
callbacks=[broken_loss_callback],
|
callbacks=[broken_loss_callback],
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
)
|
)
|
||||||
@@ -886,16 +885,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
|
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
|
||||||
|
|
||||||
# max diff broken should be very off
|
# max diff broken should be very off
|
||||||
self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2")
|
self.assertGreater(max(diff_broken), 1.5, f"Difference {max(diff_broken)} is not greater than 2")
|
||||||
|
|
||||||
loss_base = sum(base_loss_callback.losses)
|
loss_base = sum(base_loss_callback.losses)
|
||||||
loss_broken = sum(broken_loss_callback.losses)
|
loss_broken = sum(broken_loss_callback.losses)
|
||||||
|
|
||||||
# mean/sum loss should not vary too much.
|
# mean/sum loss should not vary too much.
|
||||||
relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken)
|
relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken)
|
||||||
self.assertLess(relative_diff, 0.1, f"Relative difference {relative_diff} is not within 0.1")
|
self.assertLess(relative_diff, 0.2, f"Relative difference {relative_diff} is not within 0.2")
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
|
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
import datasets
|
import datasets
|
||||||
@@ -903,14 +901,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
model_name = "roneneldan/TinyStories-33M"
|
model_name = "roneneldan/TinyStories-33M"
|
||||||
dataset_name = "wikitext"
|
dataset_name = "wikitext"
|
||||||
dataset_config = "wikitext-2-raw-v1"
|
dataset_config = "wikitext-2-raw-v1"
|
||||||
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
|
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:40]")
|
||||||
dataset = dataset.train_test_split(test_size=0.2)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
def tokenize_function(examples):
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
return tokenizer(examples["text"])
|
|
||||||
|
|
||||||
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
|
def tokenize_function(examples):
|
||||||
|
return tokenizer(examples["text"], max_length=16, padding="max_length", truncation=True)
|
||||||
|
|
||||||
|
tokenized_dataset = dataset.map(tokenize_function, batched=True)
|
||||||
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
@@ -929,7 +928,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
args_kwargs = {
|
args_kwargs = {
|
||||||
"report_to": "none",
|
"report_to": "none",
|
||||||
"logging_steps": 1,
|
"logging_steps": 1,
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"learning_rate": 3e-4,
|
"learning_rate": 3e-4,
|
||||||
"disable_tqdm": True,
|
"disable_tqdm": True,
|
||||||
}
|
}
|
||||||
@@ -942,7 +941,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
train_dataset=tokenized_dataset["train"],
|
train_dataset=tokenized_dataset,
|
||||||
callbacks=[base_loss_callback],
|
callbacks=[base_loss_callback],
|
||||||
compute_loss_func=loss_fn,
|
compute_loss_func=loss_fn,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
@@ -962,7 +961,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
train_dataset=tokenized_dataset["train"],
|
train_dataset=tokenized_dataset,
|
||||||
callbacks=[grad_accum_loss_callback],
|
callbacks=[grad_accum_loss_callback],
|
||||||
compute_loss_func=loss_fn,
|
compute_loss_func=loss_fn,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
@@ -976,7 +975,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
train_dataset=tokenized_dataset["train"],
|
train_dataset=tokenized_dataset,
|
||||||
callbacks=[broken_loss_callback],
|
callbacks=[broken_loss_callback],
|
||||||
compute_loss_func=loss_fn,
|
compute_loss_func=loss_fn,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
|
|||||||
Reference in New Issue
Block a user