Update trainer for easier handling of accumulate, compile fixes, and proper reporting (#34511)

* Update trainer for easier handling of accumulate + proper reporting

* test

* Fixup tests

* Full fix

* Fix style

* rm comment

* Fix tests

* Minimize test + remove py 311 check

* Unused import

* Forward contrib credits from discussions

* Fix reported metrics

* Refactor, good as it's going to get

* rm pad tok id check

* object detection and audio are being annoying

* Fin

* Fin x2

---------

Co-authored-by: Gyanateet Dutta <Ryukijano@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2024-11-04 07:47:34 -05:00
committed by GitHub
parent 33868a057c
commit ef976a7e18
3 changed files with 71 additions and 48 deletions

View File

@@ -28,7 +28,7 @@ import tempfile
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache, partial, wraps from functools import partial, wraps
from threading import Thread from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile from zipfile import is_zipfile
@@ -5014,7 +5014,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return self.hf_quantizer.is_trainable return self.hf_quantizer.is_trainable
@property @property
@lru_cache
def loss_function(self): def loss_function(self):
if getattr(self.config, "loss_type", None) is not None: if getattr(self.config, "loss_type", None) is not None:
loss_type = self.config.loss_type loss_type = self.config.loss_type

View File

@@ -233,7 +233,6 @@ if is_accelerate_available():
from accelerate.utils import ( from accelerate.utils import (
DistributedDataParallelKwargs, DistributedDataParallelKwargs,
DistributedType, DistributedType,
GradientAccumulationPlugin,
load_fsdp_model, load_fsdp_model,
load_fsdp_optimizer, load_fsdp_optimizer,
save_fsdp_model, save_fsdp_model,
@@ -601,8 +600,10 @@ class Trainer:
if not _is_peft_model(unwrapped_model) if not _is_peft_model(unwrapped_model)
else unwrapped_model.get_base_model().forward else unwrapped_model.get_base_model().forward
) )
forward_params = inspect.signature(model_forward).parameters
self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters self.model_accepts_loss_kwargs = (
"loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD
)
self.neftune_noise_alpha = args.neftune_noise_alpha self.neftune_noise_alpha = args.neftune_noise_alpha
@@ -2444,7 +2445,7 @@ class Trainer:
update_step += 1 update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for inputs in batch_samples: for i, inputs in enumerate(batch_samples):
step += 1 step += 1
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
# Since we perform prefetching, we need to manually set sync_gradients # Since we perform prefetching, we need to manually set sync_gradients
@@ -2484,7 +2485,13 @@ class Trainer:
if step % args.gradient_accumulation_steps == 0: if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control) self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
with self.accelerator.accumulate(model): # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch) tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
if ( if (
@@ -3636,15 +3643,11 @@ class Trainer:
with amp.scale_loss(loss, self.optimizer) as scaled_loss: with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward() scaled_loss.backward()
else: else:
if num_items_in_batch is not None:
if self.compute_loss_func or self.model_accepts_loss_kwargs:
loss *= self.args.gradient_accumulation_steps
# Average tokens across devices is orthogonal to gradient accumulation
if self.args.average_tokens_across_devices:
loss *= self.args.world_size
self.accelerator.backward(loss, **kwargs) self.accelerator.backward(loss, **kwargs)
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
return loss.detach() / self.args.gradient_accumulation_steps return loss.detach() / self.args.gradient_accumulation_steps
return loss.detach()
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
""" """
@@ -3656,9 +3659,6 @@ class Trainer:
labels = inputs.pop("labels") labels = inputs.pop("labels")
else: else:
labels = None labels = None
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
if self.model_accepts_loss_kwargs: if self.model_accepts_loss_kwargs:
loss_kwargs = {} loss_kwargs = {}
if num_items_in_batch is not None: if num_items_in_batch is not None:
@@ -3692,6 +3692,9 @@ class Trainer:
# We don't use .loss here since the model may return tuples instead of ModelOutput. # We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss return (loss, outputs) if return_outputs else loss
def is_local_process_zero(self) -> bool: def is_local_process_zero(self) -> bool:
@@ -4946,24 +4949,21 @@ class Trainer:
self.repo.git_push() self.repo.git_push()
def create_accelerator_and_postprocess(self): def create_accelerator_and_postprocess(self):
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
grad_acc_kwargs = {} grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1: if "num_steps" in grad_acc_kwargs:
if self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended. # raise because we do not know which setting is intended.
raise ValueError( raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
) )
elif "num_steps" not in grad_acc_kwargs: else:
# take the gradient_accumulation_steps setting from TrainingArguments. self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
accelerator_config = self.args.accelerator_config.to_dict() accelerator_config = self.args.accelerator_config.to_dict()
@@ -4994,7 +4994,6 @@ class Trainer:
args = { args = {
"deepspeed_plugin": self.args.deepspeed_plugin, "deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
} }
if is_accelerate_available("0.28.0"): if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config args["dataloader_config"] = dataloader_config
@@ -5090,12 +5089,18 @@ class Trainer:
batch_samples += [next(epoch_iterator)] batch_samples += [next(epoch_iterator)]
except StopIteration: except StopIteration:
break break
# Keep default behavior the same
if not self.model_accepts_loss_kwargs:
return batch_samples, None
if len(batch_samples) > 0 and "labels" in batch_samples[0]: if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection # For now we don't support object detection
try: try:
num_items_in_batch = sum( num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples] except (TypeError, AttributeError):
)
except TypeError:
pass pass
if self.args.average_tokens_across_devices:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
return batch_samples, num_items_in_batch return batch_samples, num_items_in_batch

View File

@@ -272,6 +272,19 @@ class RepeatDataset:
return {"input_ids": self.x, "labels": self.x} return {"input_ids": self.x, "labels": self.x}
class SequenceClassificationDataset:
    """Synthetic dataset of random token sequences paired with random integer labels.

    Each item is a dict with an ``"input_ids"`` list of token ids and a single
    integer ``"label"``. All data is generated eagerly in ``__init__`` with
    ``torch.randint``, so the contents depend on the global torch RNG state.

    Args:
        length: Number of examples in the dataset.
        vocab_size: Exclusive upper bound for the generated token ids.
        num_labels: Exclusive upper bound for the generated labels.
        seq_length: Number of token ids per example. Defaults to 64, which was
            previously hard-coded (and easy to confuse with ``length``).
    """

    def __init__(self, length=64, vocab_size=100, num_labels=5, seq_length=64):
        self.length = length
        # One independent random sequence per example, stored as plain Python lists.
        self.sequences = [torch.randint(0, vocab_size, (seq_length,)).tolist() for _ in range(length)]
        self.labels = torch.randint(0, num_labels, (length,)).tolist()

    def __len__(self):
        """Return the number of examples."""
        return self.length

    def __getitem__(self, i):
        """Return example ``i`` as ``{"input_ids": list of int, "label": int}``."""
        return {"input_ids": self.sequences[i], "label": self.labels[i]}
class DynamicShapesDataset: class DynamicShapesDataset:
def __init__(self, length=64, seed=42, batch_size=8): def __init__(self, length=64, seed=42, batch_size=8):
self.length = length self.length = length
@@ -1144,6 +1157,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train() train_output = trainer.train()
self.assertEqual(train_output.global_step, 10) self.assertEqual(train_output.global_step, 10)
def test_torch_compile_loss_func_compatibility(self):
    """Smoke test: run one training step of a tiny Llama model with `torch_compile=True`.

    The test has no explicit assertions; it passes when `trainer.train()` completes
    without raising.
    """
    llama_config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    model = LlamaForCausalLM(llama_config)

    token_ids = torch.randint(0, 100, (128,))
    dataset = RepeatDataset(token_ids)

    with tempfile.TemporaryDirectory() as output_dir:
        training_args = TrainingArguments(
            output_dir,
            per_device_train_batch_size=2,
            torch_compile=True,
            # compile happens on the first step, so a single step is enough
            max_steps=1,
        )
        Trainer(model=model, args=training_args, train_dataset=dataset).train()
@require_peft @require_peft
@require_bitsandbytes @require_bitsandbytes
def test_bnb_compile(self): def test_bnb_compile(self):
@@ -3676,9 +3706,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True) self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_yaml(self): def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through # Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively # and the accelerator is initialized respectively
@@ -3691,8 +3718,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
"even_batches": False, "even_batches": False,
"use_seedable_sampler": False, "use_seedable_sampler": False,
} }
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f) json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5) config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config) model = RegressionPreTrainedModel(config)
@@ -3706,9 +3731,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False) self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_dataclass(self): def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through # Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively # and the accelerator is initialized respectively
@@ -3754,10 +3776,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config) args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10) self.assertEqual(trainer.args.gradient_accumulation_steps, 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_partial(self): def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through # Checks that accelerator kwargs can be passed through