From 897a8dd89f40817201bc4aebe532a096405bdeb1 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 25 May 2022 08:16:09 -0700 Subject: [PATCH] Support compilation via Torchdynamo, AOT Autograd, NVFuser (#17308) * Support compilation via Torchdynamo, AOT Autograd, NVFuser * Address comments * Lint * Stas comments - missing quality test * Lintere * Quality test * Doc lint * Reset CUDA peak mem * Add CustomTrainer * require a single gpu Co-authored-by: Stas Bekman --- src/transformers/testing_utils.py | 6 ++ src/transformers/trainer.py | 34 ++++++++- src/transformers/trainer_seq2seq.py | 2 +- src/transformers/training_args.py | 17 +++++ src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 4 ++ tests/trainer/test_trainer.py | 95 ++++++++++++++++++++++++++ 7 files changed, 155 insertions(+), 4 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index fe61303639..58d996deef 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -70,6 +70,7 @@ from .utils import ( is_torch_tf32_available, is_torch_tpu_available, is_torchaudio_available, + is_torchdynamo_available, is_vision_available, ) @@ -464,6 +465,11 @@ else: jax_device = None +def require_torchdynamo(test_case): + """Decorator marking a test that requires TorchDynamo""" + return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case) + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 83aec582c4..c64c3a2805 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -139,8 +139,10 @@ from .utils import ( is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available, + is_torchdynamo_available, logging, ) +from .utils.generic import 
ContextManagers _is_torch_generator_available = False @@ -2172,6 +2174,32 @@ class Trainer: return inputs + def compute_loss_context_manager(self): + """ + A helper wrapper to group together context managers. + """ + return ContextManagers( + [ + self.torchdynamo_smart_context_manager(), + self.autocast_smart_context_manager(), + ] + ) + + def torchdynamo_smart_context_manager(self): + """ + A helper wrapper that creates an appropriate context manager for `torchdynamo`. + """ + ctx_manager = contextlib.nullcontext() + if is_torchdynamo_available(): + import torchdynamo + from torchdynamo.optimizations.training import aot_autograd_speedup_strategy + + if self.args.torchdynamo == "eager": + ctx_manager = torchdynamo.optimize("eager") + elif self.args.torchdynamo == "nvfuser": + ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy) + return ctx_manager + def autocast_smart_context_manager(self): """ A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired @@ -2213,7 +2241,7 @@ class Trainer: loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler) return loss_mb.reduce_mean().detach().to(self.args.device) - with self.autocast_smart_context_manager(): + with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs) if self.args.n_gpu > 1: @@ -2907,7 +2935,7 @@ class Trainer: logits = smp_nested_concat(logits_mb) else: if has_labels: - with self.autocast_smart_context_manager(): + with self.compute_loss_context_manager(): loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss = loss.mean().detach() @@ -2917,7 +2945,7 @@ class Trainer: logits = outputs[1:] else: loss = None - with self.autocast_smart_context_manager(): + with self.compute_loss_context_manager(): outputs = model(**inputs) if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) diff --git 
a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 5513b58bef..7a290fe149 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -183,7 +183,7 @@ class Seq2SeqTrainer(Trainer): generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) with torch.no_grad(): - with self.autocast_smart_context_manager(): + with self.compute_loss_context_manager(): outputs = model(**inputs) if has_labels: if self.label_smoother is not None: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 76e0132bcc..f96ecc3ab9 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -450,6 +450,9 @@ class TrainingArguments: full_determinism (`bool`, *optional*, defaults to `False`) If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in distributed training + torchdynamo (`str`, *optional*): + The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager", + "nvfuser"]. This is an experimental API and subject to change. """ output_dir: str = field( @@ -881,6 +884,20 @@ class TrainingArguments: ) }, ) + torchdynamo: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to" + " make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right" + " before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations" + " and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There" + " are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging." + " nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
+ ), + "choices": ["eager", "nvfuser"], + }, + ) def __post_init__(self): # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then). diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2106cdb007..36e2fa4330 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -130,6 +130,7 @@ from .import_utils import ( is_torch_tf32_available, is_torch_tpu_available, is_torchaudio_available, + is_torchdynamo_available, is_training_run_on_sagemaker, is_vision_available, requires_backends, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 1c6fda55f5..e5c05a0d82 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -376,6 +376,10 @@ def is_torch_tpu_available(): return importlib.util.find_spec("torch_xla.core.xla_model") is not None +def is_torchdynamo_available(): + return importlib.util.find_spec("torchdynamo") is not None + + def is_datasets_available(): return _datasets_available diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index cb9bde6329..fc089b000c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -62,6 +62,7 @@ from transformers.testing_utils import ( require_torch_non_multi_gpu, require_torch_tf32, require_torch_up_to_2_gpus, + require_torchdynamo, require_wandb, slow, ) @@ -1594,6 +1595,100 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # perfect world: fp32_init/2 == fp16_eval self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000) + @require_torch_non_multi_gpu + @require_torchdynamo + def test_torchdynamo_full_eval(self): + # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu + n_gpus = get_gpu_count() + + bs = 8 + eval_len = 16 * n_gpus + # make the params somewhat big so that there will be enough RAM consumed to be able to + # measure
things. We should get about 64KB for a+b in fp32 + a = torch.ones(1000, bs) + 0.001 + b = torch.ones(1000, bs) - 0.001 + + # 1. Default - without TorchDynamo + trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len) + metrics = trainer.evaluate() + original_eval_loss = metrics["eval_loss"] + del trainer + + # 2. TorchDynamo eager + trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager") + metrics = trainer.evaluate() + self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss) + del trainer + + # 3. TorchDynamo nvfuser + trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser") + metrics = trainer.evaluate() + self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss) + + @require_torch_non_multi_gpu + @require_torchdynamo + def test_torchdynamo_memory(self): + # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu + class CustomTrainer(Trainer): + def compute_loss(self, model, inputs, return_outputs=False): + x = inputs["x"] + output = model(x) + if self.args.n_gpu == 1: + return output.mean() + return output + + class MyModule(torch.nn.Module): + """Simple module that does aggressive fusion""" + + def __init__(self): + super().__init__() + + def forward(self, x): + for _ in range(20): + x = torch.nn.functional.relu(x) + return x + + mod = MyModule() + + # 1. Default - without TorchDynamo + a = torch.ones(1024, 1024, device="cuda", requires_grad=True) + a.grad = None + trainer = CustomTrainer(model=mod) + # warmup + for _ in range(10): + orig_loss = trainer.training_step(mod, {"x": a}) + + torch.cuda.reset_peak_memory_stats() + orig_loss = trainer.training_step(mod, {"x": a}) + orig_peak_mem = torch.cuda.max_memory_allocated() + del trainer + + # Reset the peak for another measurement + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # 2. 
TorchDynamo nvfuser + a = torch.ones(1024, 1024, device="cuda", requires_grad=True) + a.grad = None + args = TrainingArguments(output_dir="None", torchdynamo="nvfuser") + trainer = CustomTrainer(model=mod, args=args) + # warmup + for _ in range(10): + loss = trainer.training_step(mod, {"x": a}) + + torch.cuda.reset_peak_memory_stats() + loss = trainer.training_step(mod, {"x": a}) + peak_mem = torch.cuda.max_memory_allocated() + del trainer + + # Functional check + self.assertAlmostEqual(loss, orig_loss) + + # AOT Autograd recomputation and nvfuser recomputation optimization + # aggressively fuses the operations and reduces the memory footprint. + self.assertGreater(orig_peak_mem, peak_mem * 2) + @require_torch_gpu @require_torch_bf16 def test_bf16_full_eval(self):