Update trainer for easier handling of accumulate, compile fixes, and proper reporting (#34511)
* Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta <Ryukijano@users.noreply.github.com>
This commit is contained in:
@@ -28,7 +28,7 @@ import tempfile
|
|||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache, partial, wraps
|
from functools import partial, wraps
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
from zipfile import is_zipfile
|
from zipfile import is_zipfile
|
||||||
@@ -5014,7 +5014,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
return self.hf_quantizer.is_trainable
|
return self.hf_quantizer.is_trainable
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@lru_cache
|
|
||||||
def loss_function(self):
|
def loss_function(self):
|
||||||
if getattr(self.config, "loss_type", None) is not None:
|
if getattr(self.config, "loss_type", None) is not None:
|
||||||
loss_type = self.config.loss_type
|
loss_type = self.config.loss_type
|
||||||
|
|||||||
@@ -233,7 +233,6 @@ if is_accelerate_available():
|
|||||||
from accelerate.utils import (
|
from accelerate.utils import (
|
||||||
DistributedDataParallelKwargs,
|
DistributedDataParallelKwargs,
|
||||||
DistributedType,
|
DistributedType,
|
||||||
GradientAccumulationPlugin,
|
|
||||||
load_fsdp_model,
|
load_fsdp_model,
|
||||||
load_fsdp_optimizer,
|
load_fsdp_optimizer,
|
||||||
save_fsdp_model,
|
save_fsdp_model,
|
||||||
@@ -601,8 +600,10 @@ class Trainer:
|
|||||||
if not _is_peft_model(unwrapped_model)
|
if not _is_peft_model(unwrapped_model)
|
||||||
else unwrapped_model.get_base_model().forward
|
else unwrapped_model.get_base_model().forward
|
||||||
)
|
)
|
||||||
|
forward_params = inspect.signature(model_forward).parameters
|
||||||
self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters
|
self.model_accepts_loss_kwargs = (
|
||||||
|
"loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD
|
||||||
|
)
|
||||||
|
|
||||||
self.neftune_noise_alpha = args.neftune_noise_alpha
|
self.neftune_noise_alpha = args.neftune_noise_alpha
|
||||||
|
|
||||||
@@ -2444,7 +2445,7 @@ class Trainer:
|
|||||||
update_step += 1
|
update_step += 1
|
||||||
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
|
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
|
||||||
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
|
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
|
||||||
for inputs in batch_samples:
|
for i, inputs in enumerate(batch_samples):
|
||||||
step += 1
|
step += 1
|
||||||
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
|
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
|
||||||
# Since we perform prefetching, we need to manually set sync_gradients
|
# Since we perform prefetching, we need to manually set sync_gradients
|
||||||
@@ -2484,7 +2485,13 @@ class Trainer:
|
|||||||
if step % args.gradient_accumulation_steps == 0:
|
if step % args.gradient_accumulation_steps == 0:
|
||||||
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
|
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
|
||||||
|
|
||||||
with self.accelerator.accumulate(model):
|
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
|
||||||
|
context = (
|
||||||
|
functools.partial(self.accelerator.no_sync, model=model)
|
||||||
|
if i == len(batch_samples) - 1
|
||||||
|
else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
with context():
|
||||||
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -3636,15 +3643,11 @@ class Trainer:
|
|||||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
else:
|
else:
|
||||||
if num_items_in_batch is not None:
|
|
||||||
if self.compute_loss_func or self.model_accepts_loss_kwargs:
|
|
||||||
loss *= self.args.gradient_accumulation_steps
|
|
||||||
# Average tokens across devices is orthogonal to gradient accumulation
|
|
||||||
if self.args.average_tokens_across_devices:
|
|
||||||
loss *= self.args.world_size
|
|
||||||
self.accelerator.backward(loss, **kwargs)
|
self.accelerator.backward(loss, **kwargs)
|
||||||
|
# Finally we need to normalize the loss for reporting
|
||||||
return loss.detach() / self.args.gradient_accumulation_steps
|
if num_items_in_batch is None:
|
||||||
|
return loss.detach() / self.args.gradient_accumulation_steps
|
||||||
|
return loss.detach()
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||||
"""
|
"""
|
||||||
@@ -3656,9 +3659,6 @@ class Trainer:
|
|||||||
labels = inputs.pop("labels")
|
labels = inputs.pop("labels")
|
||||||
else:
|
else:
|
||||||
labels = None
|
labels = None
|
||||||
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
|
|
||||||
num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
|
|
||||||
num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
|
|
||||||
if self.model_accepts_loss_kwargs:
|
if self.model_accepts_loss_kwargs:
|
||||||
loss_kwargs = {}
|
loss_kwargs = {}
|
||||||
if num_items_in_batch is not None:
|
if num_items_in_batch is not None:
|
||||||
@@ -3692,6 +3692,9 @@ class Trainer:
|
|||||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||||
|
|
||||||
|
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||||
|
loss *= self.accelerator.num_processes
|
||||||
|
|
||||||
return (loss, outputs) if return_outputs else loss
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
def is_local_process_zero(self) -> bool:
|
def is_local_process_zero(self) -> bool:
|
||||||
@@ -4946,24 +4949,21 @@ class Trainer:
|
|||||||
self.repo.git_push()
|
self.repo.git_push()
|
||||||
|
|
||||||
def create_accelerator_and_postprocess(self):
|
def create_accelerator_and_postprocess(self):
|
||||||
|
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
|
||||||
grad_acc_kwargs = {}
|
grad_acc_kwargs = {}
|
||||||
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
|
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
|
||||||
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
|
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
|
||||||
|
|
||||||
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
|
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
|
||||||
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
|
if "num_steps" in grad_acc_kwargs:
|
||||||
# raise because we do not know which setting is intended.
|
if self.args.gradient_accumulation_steps > 1:
|
||||||
raise ValueError(
|
# raise because we do not know which setting is intended.
|
||||||
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
|
raise ValueError(
|
||||||
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
|
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
|
||||||
)
|
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
|
||||||
elif "num_steps" not in grad_acc_kwargs:
|
)
|
||||||
# take the gradient_accumulation_steps setting from TrainingArguments.
|
else:
|
||||||
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
|
self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
|
||||||
|
|
||||||
grad_acc_kwargs["sync_with_dataloader"] = False
|
|
||||||
|
|
||||||
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
|
||||||
|
|
||||||
accelerator_config = self.args.accelerator_config.to_dict()
|
accelerator_config = self.args.accelerator_config.to_dict()
|
||||||
|
|
||||||
@@ -4994,7 +4994,6 @@ class Trainer:
|
|||||||
|
|
||||||
args = {
|
args = {
|
||||||
"deepspeed_plugin": self.args.deepspeed_plugin,
|
"deepspeed_plugin": self.args.deepspeed_plugin,
|
||||||
"gradient_accumulation_plugin": gradient_accumulation_plugin,
|
|
||||||
}
|
}
|
||||||
if is_accelerate_available("0.28.0"):
|
if is_accelerate_available("0.28.0"):
|
||||||
args["dataloader_config"] = dataloader_config
|
args["dataloader_config"] = dataloader_config
|
||||||
@@ -5090,12 +5089,18 @@ class Trainer:
|
|||||||
batch_samples += [next(epoch_iterator)]
|
batch_samples += [next(epoch_iterator)]
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Keep default behavior the same
|
||||||
|
if not self.model_accepts_loss_kwargs:
|
||||||
|
return batch_samples, None
|
||||||
|
|
||||||
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
|
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
|
||||||
# For now we don't support object detection
|
# For now we don't support object detection
|
||||||
try:
|
try:
|
||||||
num_items_in_batch = sum(
|
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
|
||||||
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
|
except (TypeError, AttributeError):
|
||||||
)
|
|
||||||
except TypeError:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if self.args.average_tokens_across_devices:
|
||||||
|
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
|
||||||
return batch_samples, num_items_in_batch
|
return batch_samples, num_items_in_batch
|
||||||
|
|||||||
@@ -272,6 +272,19 @@ class RepeatDataset:
|
|||||||
return {"input_ids": self.x, "labels": self.x}
|
return {"input_ids": self.x, "labels": self.x}
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceClassificationDataset:
|
||||||
|
def __init__(self, length=64, vocab_size=100, num_labels=5):
|
||||||
|
self.length = length
|
||||||
|
self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)]
|
||||||
|
self.labels = torch.randint(0, num_labels, (length,)).tolist()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return {"input_ids": self.sequences[i], "label": self.labels[i]}
|
||||||
|
|
||||||
|
|
||||||
class DynamicShapesDataset:
|
class DynamicShapesDataset:
|
||||||
def __init__(self, length=64, seed=42, batch_size=8):
|
def __init__(self, length=64, seed=42, batch_size=8):
|
||||||
self.length = length
|
self.length = length
|
||||||
@@ -1144,6 +1157,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
train_output = trainer.train()
|
train_output = trainer.train()
|
||||||
self.assertEqual(train_output.global_step, 10)
|
self.assertEqual(train_output.global_step, 10)
|
||||||
|
|
||||||
|
def test_torch_compile_loss_func_compatibility(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
args = TrainingArguments(
|
||||||
|
tmp_dir,
|
||||||
|
per_device_train_batch_size=2,
|
||||||
|
torch_compile=True,
|
||||||
|
max_steps=1, # compile happens on the first step
|
||||||
|
)
|
||||||
|
trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset) # noqa
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
@require_peft
|
@require_peft
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
def test_bnb_compile(self):
|
def test_bnb_compile(self):
|
||||||
@@ -3676,9 +3706,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||||
|
|
||||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
|
||||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
|
||||||
|
|
||||||
def test_accelerator_config_from_yaml(self):
|
def test_accelerator_config_from_yaml(self):
|
||||||
# Checks that accelerator kwargs can be passed through
|
# Checks that accelerator kwargs can be passed through
|
||||||
# and the accelerator is initialized respectively
|
# and the accelerator is initialized respectively
|
||||||
@@ -3691,8 +3718,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
"even_batches": False,
|
"even_batches": False,
|
||||||
"use_seedable_sampler": False,
|
"use_seedable_sampler": False,
|
||||||
}
|
}
|
||||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
|
||||||
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
|
|
||||||
json.dump(accelerator_config, f)
|
json.dump(accelerator_config, f)
|
||||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
model = RegressionPreTrainedModel(config)
|
model = RegressionPreTrainedModel(config)
|
||||||
@@ -3706,9 +3731,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||||
|
|
||||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
|
||||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
|
||||||
|
|
||||||
def test_accelerator_config_from_dataclass(self):
|
def test_accelerator_config_from_dataclass(self):
|
||||||
# Checks that accelerator kwargs can be passed through
|
# Checks that accelerator kwargs can be passed through
|
||||||
# and the accelerator is initialized respectively
|
# and the accelerator is initialized respectively
|
||||||
@@ -3754,10 +3776,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
|
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
|
||||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
|
self.assertEqual(trainer.args.gradient_accumulation_steps, 10)
|
||||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
|
|
||||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
|
|
||||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
|
||||||
|
|
||||||
def test_accelerator_config_from_partial(self):
|
def test_accelerator_config_from_partial(self):
|
||||||
# Checks that accelerator kwargs can be passed through
|
# Checks that accelerator kwargs can be passed through
|
||||||
|
|||||||
Reference in New Issue
Block a user