Support compilation via Torchdynamo, AOT Autograd, NVFuser (#17308)
* Support compilation via Torchdynamo, AOT Autograd, NVFuser * Address comments * Lint * Stas comments - missing quality test * Lintere * Quality test * Doc lint * Reset CUDA peak mem * Add CustomTrainer * require a single gpu Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -70,6 +70,7 @@ from .utils import (
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
is_torchdynamo_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
@@ -464,6 +465,11 @@ else:
|
||||
jax_device = None
|
||||
|
||||
|
||||
def require_torchdynamo(test_case):
|
||||
"""Decorator marking a test that requires TorchDynamo"""
|
||||
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||
|
||||
@@ -139,8 +139,10 @@ from .utils import (
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_tpu_available,
|
||||
is_torchdynamo_available,
|
||||
logging,
|
||||
)
|
||||
from .utils.generic import ContextManagers
|
||||
|
||||
|
||||
_is_torch_generator_available = False
|
||||
@@ -2172,6 +2174,32 @@ class Trainer:
|
||||
|
||||
return inputs
|
||||
|
||||
def compute_loss_context_manager(self):
|
||||
"""
|
||||
A helper wrapper to group together context managers.
|
||||
"""
|
||||
return ContextManagers(
|
||||
[
|
||||
self.torchdynamo_smart_context_manager(),
|
||||
self.autocast_smart_context_manager(),
|
||||
]
|
||||
)
|
||||
|
||||
def torchdynamo_smart_context_manager(self):
|
||||
"""
|
||||
A helper wrapper that creates an appropriate context manager for `torchdynamo`.
|
||||
"""
|
||||
ctx_manager = contextlib.nullcontext()
|
||||
if is_torchdynamo_available():
|
||||
import torchdynamo
|
||||
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
|
||||
|
||||
if self.args.torchdynamo == "eager":
|
||||
ctx_manager = torchdynamo.optimize("eager")
|
||||
elif self.args.torchdynamo == "nvfuser":
|
||||
ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
|
||||
return ctx_manager
|
||||
|
||||
def autocast_smart_context_manager(self):
|
||||
"""
|
||||
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
|
||||
@@ -2213,7 +2241,7 @@ class Trainer:
|
||||
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
|
||||
return loss_mb.reduce_mean().detach().to(self.args.device)
|
||||
|
||||
with self.autocast_smart_context_manager():
|
||||
with self.compute_loss_context_manager():
|
||||
loss = self.compute_loss(model, inputs)
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
@@ -2907,7 +2935,7 @@ class Trainer:
|
||||
logits = smp_nested_concat(logits_mb)
|
||||
else:
|
||||
if has_labels:
|
||||
with self.autocast_smart_context_manager():
|
||||
with self.compute_loss_context_manager():
|
||||
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||
loss = loss.mean().detach()
|
||||
|
||||
@@ -2917,7 +2945,7 @@ class Trainer:
|
||||
logits = outputs[1:]
|
||||
else:
|
||||
loss = None
|
||||
with self.autocast_smart_context_manager():
|
||||
with self.compute_loss_context_manager():
|
||||
outputs = model(**inputs)
|
||||
if isinstance(outputs, dict):
|
||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
|
||||
|
||||
@@ -183,7 +183,7 @@ class Seq2SeqTrainer(Trainer):
|
||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
||||
|
||||
with torch.no_grad():
|
||||
with self.autocast_smart_context_manager():
|
||||
with self.compute_loss_context_manager():
|
||||
outputs = model(**inputs)
|
||||
if has_labels:
|
||||
if self.label_smoother is not None:
|
||||
|
||||
@@ -450,6 +450,9 @@ class TrainingArguments:
|
||||
full_determinism (`bool`, *optional*, defaults to `False`)
|
||||
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
|
||||
distributed training
|
||||
torchdynamo (`str`, *optional*):
|
||||
The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
|
||||
"nvfuser]. This is an experimental API and subject to change.
|
||||
"""
|
||||
|
||||
output_dir: str = field(
|
||||
@@ -881,6 +884,20 @@ class TrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
torchdynamo: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
|
||||
" make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
|
||||
" before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
|
||||
" and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
|
||||
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
|
||||
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
|
||||
),
|
||||
"choices": ["eager", "nvfuser"],
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
|
||||
|
||||
@@ -130,6 +130,7 @@ from .import_utils import (
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
is_torchdynamo_available,
|
||||
is_training_run_on_sagemaker,
|
||||
is_vision_available,
|
||||
requires_backends,
|
||||
|
||||
@@ -376,6 +376,10 @@ def is_torch_tpu_available():
|
||||
return importlib.util.find_spec("torch_xla.core.xla_model") is not None
|
||||
|
||||
|
||||
def is_torchdynamo_available():
|
||||
return importlib.util.find_spec("torchdynamo") is not None
|
||||
|
||||
|
||||
def is_datasets_available():
|
||||
return _datasets_available
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
require_torchdynamo,
|
||||
require_wandb,
|
||||
slow,
|
||||
)
|
||||
@@ -1594,6 +1595,100 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# perfect world: fp32_init/2 == fp16_eval
|
||||
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torchdynamo
|
||||
def test_torchdynamo_full_eval(self):
|
||||
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
|
||||
n_gpus = get_gpu_count()
|
||||
|
||||
bs = 8
|
||||
eval_len = 16 * n_gpus
|
||||
# make the params are somewhat big so that there will be enough RAM consumed to be able to
|
||||
# measure things. We should get about 64KB for a+b in fp32
|
||||
a = torch.ones(1000, bs) + 0.001
|
||||
b = torch.ones(1000, bs) - 0.001
|
||||
|
||||
# 1. Default - without TorchDynamo
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
|
||||
metrics = trainer.evaluate()
|
||||
original_eval_loss = metrics["eval_loss"]
|
||||
del trainer
|
||||
|
||||
# 2. TorchDynamo eager
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager")
|
||||
metrics = trainer.evaluate()
|
||||
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
|
||||
del trainer
|
||||
|
||||
# 3. TorchDynamo nvfuser
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
|
||||
metrics = trainer.evaluate()
|
||||
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torchdynamo
|
||||
def test_torchdynamo_memory(self):
|
||||
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
x = inputs["x"]
|
||||
output = model(x)
|
||||
if self.args.n_gpu == 1:
|
||||
return output.mean()
|
||||
return output
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
"""Simple module that does aggressive fusion"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
for _ in range(20):
|
||||
x = torch.nn.functional.relu(x)
|
||||
return x
|
||||
|
||||
mod = MyModule()
|
||||
|
||||
# 1. Default - without TorchDynamo
|
||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
||||
a.grad = None
|
||||
trainer = CustomTrainer(model=mod)
|
||||
# warmup
|
||||
for _ in range(10):
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
orig_peak_mem = torch.cuda.max_memory_allocated()
|
||||
del trainer
|
||||
|
||||
# Reset the peak for another measurement
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# 2. TorchDynamo nvfuser
|
||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
||||
a.grad = None
|
||||
args = TrainingArguments(output_dir="None", torchdynamo="nvfuser")
|
||||
trainer = CustomTrainer(model=mod, args=args)
|
||||
# warmup
|
||||
for _ in range(10):
|
||||
loss = trainer.training_step(mod, {"x": a})
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
loss = trainer.training_step(mod, {"x": a})
|
||||
peak_mem = torch.cuda.max_memory_allocated()
|
||||
del trainer
|
||||
|
||||
# Functional check
|
||||
self.assertAlmostEqual(loss, orig_loss)
|
||||
|
||||
# AOT Autograd recomputaion and nvfuser recomputation optimization
|
||||
# aggressively fuses the operations and reduce the memory footprint.
|
||||
self.assertGreater(orig_peak_mem, peak_mem * 2)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_bf16
|
||||
def test_bf16_full_eval(self):
|
||||
|
||||
Reference in New Issue
Block a user