Update trainer for easier handling of accumulate, compile fixes, and proper reporting (#34511)
* Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta <Ryukijano@users.noreply.github.com>
This commit is contained in:
@@ -272,6 +272,19 @@ class RepeatDataset:
|
||||
return {"input_ids": self.x, "labels": self.x}
|
||||
|
||||
|
||||
class SequenceClassificationDataset:
|
||||
def __init__(self, length=64, vocab_size=100, num_labels=5):
|
||||
self.length = length
|
||||
self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)]
|
||||
self.labels = torch.randint(0, num_labels, (length,)).tolist()
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
return {"input_ids": self.sequences[i], "label": self.labels[i]}
|
||||
|
||||
|
||||
class DynamicShapesDataset:
|
||||
def __init__(self, length=64, seed=42, batch_size=8):
|
||||
self.length = length
|
||||
@@ -1144,6 +1157,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, 10)
|
||||
|
||||
def test_torch_compile_loss_func_compatibility(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
torch_compile=True,
|
||||
max_steps=1, # compile happens on the first step
|
||||
)
|
||||
trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset) # noqa
|
||||
trainer.train()
|
||||
|
||||
@require_peft
|
||||
@require_bitsandbytes
|
||||
def test_bnb_compile(self):
|
||||
@@ -3676,9 +3706,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
|
||||
def test_accelerator_config_from_yaml(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
@@ -3691,8 +3718,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": False,
|
||||
}
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
|
||||
json.dump(accelerator_config, f)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
@@ -3706,9 +3731,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
|
||||
def test_accelerator_config_from_dataclass(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
@@ -3754,10 +3776,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
self.assertEqual(trainer.args.gradient_accumulation_steps, 10)
|
||||
|
||||
def test_accelerator_config_from_partial(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
|
||||
Reference in New Issue
Block a user