Support compilation via Torchdynamo, AOT Autograd, NVFuser (#17308)
* Support compilation via Torchdynamo, AOT Autograd, NVFuser * Address comments * Lint * Stas comments - missing quality test * Lintere * Quality test * Doc lint * Reset CUDA peak mem * Add CustomTrainer * require a single gpu Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -70,6 +70,7 @@ from .utils import (
|
|||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
|
is_torchdynamo_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -464,6 +465,11 @@ else:
|
|||||||
jax_device = None
|
jax_device = None
|
||||||
|
|
||||||
|
|
||||||
|
def require_torchdynamo(test_case):
    """
    Decorator marking a test that requires TorchDynamo.

    The decorated test is skipped unless `is_torchdynamo_available()` returns a truthy value.
    """
    skip_unless_installed = unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")
    return skip_unless_installed(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_gpu(test_case):
|
def require_torch_gpu(test_case):
|
||||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||||
|
|||||||
@@ -139,8 +139,10 @@ from .utils import (
|
|||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
|
is_torchdynamo_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
from .utils.generic import ContextManagers
|
||||||
|
|
||||||
|
|
||||||
_is_torch_generator_available = False
|
_is_torch_generator_available = False
|
||||||
@@ -2172,6 +2174,32 @@ class Trainer:
|
|||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
def compute_loss_context_manager(self):
    """
    A helper wrapper to group together context managers.

    Returns a single `ContextManagers` object wrapping, in this order, the TorchDynamo context manager and the
    autocast context manager.
    """
    dynamo_ctx = self.torchdynamo_smart_context_manager()
    autocast_ctx = self.autocast_smart_context_manager()
    return ContextManagers([dynamo_ctx, autocast_ctx])
|
||||||
|
|
||||||
|
def torchdynamo_smart_context_manager(self):
    """
    A helper wrapper that creates an appropriate context manager for `torchdynamo`.

    Returns:
        `torchdynamo.optimize(...)` configured for the backend named by `self.args.torchdynamo` (`"eager"` or
        `"nvfuser"`) when TorchDynamo is installed; a no-op `contextlib.nullcontext()` in every other case.
    """
    backend = self.args.torchdynamo
    # Look at the requested backend first: runs that did not ask for TorchDynamo must not pay for importing it
    # (this method is called on every training/prediction step).
    if backend not in ("eager", "nvfuser") or not is_torchdynamo_available():
        return contextlib.nullcontext()

    import torchdynamo

    if backend == "eager":
        return torchdynamo.optimize("eager")

    # "nvfuser": the AOT Autograd strategy is only imported when this backend is actually selected.
    from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

    return torchdynamo.optimize(aot_autograd_speedup_strategy)
|
||||||
|
|
||||||
def autocast_smart_context_manager(self):
|
def autocast_smart_context_manager(self):
|
||||||
"""
|
"""
|
||||||
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
|
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
|
||||||
@@ -2213,7 +2241,7 @@ class Trainer:
|
|||||||
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
|
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
|
||||||
return loss_mb.reduce_mean().detach().to(self.args.device)
|
return loss_mb.reduce_mean().detach().to(self.args.device)
|
||||||
|
|
||||||
with self.autocast_smart_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
loss = self.compute_loss(model, inputs)
|
loss = self.compute_loss(model, inputs)
|
||||||
|
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
@@ -2907,7 +2935,7 @@ class Trainer:
|
|||||||
logits = smp_nested_concat(logits_mb)
|
logits = smp_nested_concat(logits_mb)
|
||||||
else:
|
else:
|
||||||
if has_labels:
|
if has_labels:
|
||||||
with self.autocast_smart_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||||
loss = loss.mean().detach()
|
loss = loss.mean().detach()
|
||||||
|
|
||||||
@@ -2917,7 +2945,7 @@ class Trainer:
|
|||||||
logits = outputs[1:]
|
logits = outputs[1:]
|
||||||
else:
|
else:
|
||||||
loss = None
|
loss = None
|
||||||
with self.autocast_smart_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
if isinstance(outputs, dict):
|
if isinstance(outputs, dict):
|
||||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
|
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with self.autocast_smart_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
if has_labels:
|
if has_labels:
|
||||||
if self.label_smoother is not None:
|
if self.label_smoother is not None:
|
||||||
|
|||||||
@@ -450,6 +450,9 @@ class TrainingArguments:
|
|||||||
full_determinism (`bool`, *optional*, defaults to `False`)
|
full_determinism (`bool`, *optional*, defaults to `False`)
|
||||||
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
|
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
|
||||||
distributed training
|
distributed training
|
||||||
|
torchdynamo (`str`, *optional*):
|
||||||
|
The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
|
||||||
|
"nvfuser"]. This is an experimental API and subject to change.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -881,6 +884,20 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
torchdynamo: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
|
||||||
|
" make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
|
||||||
|
" before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
|
||||||
|
" and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
|
||||||
|
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
|
||||||
|
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
|
||||||
|
),
|
||||||
|
"choices": ["eager", "nvfuser"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
|
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ from .import_utils import (
|
|||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
|
is_torchdynamo_available,
|
||||||
is_training_run_on_sagemaker,
|
is_training_run_on_sagemaker,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
|
|||||||
@@ -376,6 +376,10 @@ def is_torch_tpu_available():
|
|||||||
return importlib.util.find_spec("torch_xla.core.xla_model") is not None
|
return importlib.util.find_spec("torch_xla.core.xla_model") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_torchdynamo_available():
    """Return `True` when an import spec for the `torchdynamo` package can be found, `False` otherwise."""
    torchdynamo_spec = importlib.util.find_spec("torchdynamo")
    return torchdynamo_spec is not None
|
||||||
|
|
||||||
|
|
||||||
def is_datasets_available():
|
def is_datasets_available():
|
||||||
return _datasets_available
|
return _datasets_available
|
||||||
|
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch_non_multi_gpu,
|
require_torch_non_multi_gpu,
|
||||||
require_torch_tf32,
|
require_torch_tf32,
|
||||||
require_torch_up_to_2_gpus,
|
require_torch_up_to_2_gpus,
|
||||||
|
require_torchdynamo,
|
||||||
require_wandb,
|
require_wandb,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
@@ -1594,6 +1595,100 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
# perfect world: fp32_init/2 == fp16_eval
|
# perfect world: fp32_init/2 == fp16_eval
|
||||||
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||||
|
|
||||||
|
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_full_eval(self):
    # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
    n_gpus = get_gpu_count()

    bs = 8
    eval_len = 16 * n_gpus
    # regression parameters shared by every trainer built below
    a = torch.ones(1000, bs) + 0.001
    b = torch.ones(1000, bs) - 0.001

    # 1. Default - without TorchDynamo: this eval loss is the reference value
    baseline_trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
    baseline_metrics = baseline_trainer.evaluate()
    original_eval_loss = baseline_metrics["eval_loss"]
    del baseline_trainer

    # 2. TorchDynamo eager, then 3. TorchDynamo nvfuser: each must reproduce the reference eval loss
    for backend in ("eager", "nvfuser"):
        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo=backend)
        metrics = trainer.evaluate()
        self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
        del trainer
|
||||||
|
|
||||||
|
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
    # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
    class CustomTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False):
            output = model(inputs["x"])
            # with exactly one GPU the output is reduced to its mean; otherwise it is returned as-is
            return output.mean() if self.args.n_gpu == 1 else output

    class MyModule(torch.nn.Module):
        """Simple module that does aggressive fusion"""

        def __init__(self):
            super().__init__()

        def forward(self, x):
            # 20 chained ReLU applications
            for _ in range(20):
                x = torch.nn.functional.relu(x)
            return x

    mod = MyModule()
    warmup_steps = 10

    # 1. Default - without TorchDynamo
    a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
    a.grad = None
    trainer = CustomTrainer(model=mod)
    # warmup
    for _ in range(warmup_steps):
        orig_loss = trainer.training_step(mod, {"x": a})

    # only the single step that follows the reset contributes to the recorded peak
    torch.cuda.reset_peak_memory_stats()
    orig_loss = trainer.training_step(mod, {"x": a})
    orig_peak_mem = torch.cuda.max_memory_allocated()
    del trainer

    # Reset the peak for another measurement
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # 2. TorchDynamo nvfuser
    a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
    a.grad = None
    args = TrainingArguments(output_dir="None", torchdynamo="nvfuser")
    trainer = CustomTrainer(model=mod, args=args)
    # warmup
    for _ in range(warmup_steps):
        loss = trainer.training_step(mod, {"x": a})

    torch.cuda.reset_peak_memory_stats()
    loss = trainer.training_step(mod, {"x": a})
    peak_mem = torch.cuda.max_memory_allocated()
    del trainer

    # Functional check
    self.assertAlmostEqual(loss, orig_loss)

    # AOT Autograd recomputation and nvfuser recomputation optimization
    # aggressively fuse the operations and reduce the memory footprint.
    self.assertGreater(orig_peak_mem, peak_mem * 2)
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_torch_bf16
|
@require_torch_bf16
|
||||||
def test_bf16_full_eval(self):
|
def test_bf16_full_eval(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user