enable misc cases on XPU & use device agnostic APIs for cases in tests (#38192)
* use device agnostic APIs in tests Signed-off-by: Matrix Yao <matrix.yao@intel.com> * more Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> * add reset_peak_memory_stats API Signed-off-by: YAO Matrix <matrix.yao@intel.com> * update --------- Signed-off-by: Matrix Yao <matrix.yao@intel.com> Signed-off-by: YAO Matrix <matrix.yao@intel.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -66,6 +66,7 @@ from transformers.testing_utils import (
|
||||
backend_max_memory_allocated,
|
||||
backend_memory_allocated,
|
||||
backend_reset_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
evaluate_side_effect_factory,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
@@ -1654,7 +1655,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))
|
||||
|
||||
def test_train_and_eval_dataloaders(self):
|
||||
if torch_device == "cuda":
|
||||
if torch_device in ["cuda", "xpu"]:
|
||||
n_gpu = max(1, backend_device_count(torch_device))
|
||||
else:
|
||||
n_gpu = 1
|
||||
@@ -4106,7 +4107,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
mod = MyModule()
|
||||
|
||||
# 1. without TorchDynamo (eager baseline)
|
||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
||||
a = torch.ones(1024, 1024, device=torch_device, requires_grad=True)
|
||||
a.grad = None
|
||||
trainer = CustomTrainer(model=mod)
|
||||
# warmup
|
||||
@@ -4115,17 +4116,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
# resets
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
orig_peak_mem = torch.cuda.max_memory_allocated()
|
||||
orig_peak_mem = backend_max_memory_allocated(torch_device)
|
||||
torchdynamo.reset()
|
||||
del trainer
|
||||
|
||||
# 2. TorchDynamo nvfuser
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
||||
a = torch.ones(1024, 1024, device=torch_device, requires_grad=True)
|
||||
a.grad = None
|
||||
args = TrainingArguments(output_dir=tmp_dir, torch_compile_backend="nvfuser")
|
||||
trainer = CustomTrainer(model=mod, args=args)
|
||||
@@ -4135,11 +4136,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
# resets
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
loss = trainer.training_step(mod, {"x": a})
|
||||
peak_mem = torch.cuda.max_memory_allocated()
|
||||
peak_mem = backend_max_memory_allocated(torch_device)
|
||||
torchdynamo.reset()
|
||||
del trainer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user