enable xpu in test_trainer (#37774)
* enable xpu in test_trainer Signed-off-by: YAO Matrix <matrix.yao@intel.com> * fix style Signed-off-by: YAO Matrix <matrix.yao@intel.com> * enhance _device_agnostic_dispatch to cover value Signed-off-by: Yao Matrix <matrix.yao@intel.com> * add default values for torch not available case Signed-off-by: YAO Matrix <matrix.yao@intel.com> --------- Signed-off-by: YAO Matrix <matrix.yao@intel.com> Signed-off-by: Yao Matrix <matrix.yao@intel.com>
This commit is contained in:
@@ -2946,10 +2946,10 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
|
|||||||
|
|
||||||
fn = dispatch_table[device]
|
fn = dispatch_table[device]
|
||||||
|
|
||||||
# Some device agnostic functions return values. Need to guard against `None`
|
# Some device agnostic functions return values or None, will return then directly.
|
||||||
# instead at user level.
|
if not callable(fn):
|
||||||
if fn is None:
|
return fn
|
||||||
return None
|
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -2971,10 +2971,29 @@ if is_torch_available():
|
|||||||
"cpu": lambda: 0,
|
"cpu": lambda: 0,
|
||||||
"default": lambda: 1,
|
"default": lambda: 1,
|
||||||
}
|
}
|
||||||
|
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
|
||||||
|
"cuda": torch.cuda.reset_max_memory_allocated,
|
||||||
|
"cpu": None,
|
||||||
|
"default": None,
|
||||||
|
}
|
||||||
|
BACKEND_MAX_MEMORY_ALLOCATED = {
|
||||||
|
"cuda": torch.cuda.max_memory_allocated,
|
||||||
|
"cpu": 0,
|
||||||
|
"default": 0,
|
||||||
|
}
|
||||||
|
BACKEND_MEMORY_ALLOCATED = {
|
||||||
|
"cuda": torch.cuda.memory_allocated,
|
||||||
|
"cpu": 0,
|
||||||
|
"default": 0,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
BACKEND_MANUAL_SEED = {"default": None}
|
BACKEND_MANUAL_SEED = {"default": None}
|
||||||
BACKEND_EMPTY_CACHE = {"default": None}
|
BACKEND_EMPTY_CACHE = {"default": None}
|
||||||
BACKEND_DEVICE_COUNT = {"default": lambda: 0}
|
BACKEND_DEVICE_COUNT = {"default": lambda: 0}
|
||||||
|
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
|
||||||
|
BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
|
||||||
|
BACKEND_MEMORY_ALLOCATED = {"default": 0}
|
||||||
|
|
||||||
|
|
||||||
if is_torch_hpu_available():
|
if is_torch_hpu_available():
|
||||||
BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
|
BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
|
||||||
@@ -2994,6 +3013,9 @@ if is_torch_xpu_available():
|
|||||||
BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
|
BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
|
||||||
BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
|
BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
|
||||||
BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
|
BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
|
||||||
|
BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
|
||||||
|
BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
|
||||||
|
BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
|
||||||
|
|
||||||
if is_torch_xla_available():
|
if is_torch_xla_available():
|
||||||
BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
|
BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
|
||||||
@@ -3013,6 +3035,18 @@ def backend_device_count(device: str):
|
|||||||
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
|
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
|
||||||
|
|
||||||
|
|
||||||
|
def backend_reset_max_memory_allocated(device: str):
|
||||||
|
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
|
||||||
|
|
||||||
|
|
||||||
|
def backend_max_memory_allocated(device: str):
|
||||||
|
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
|
||||||
|
|
||||||
|
|
||||||
|
def backend_memory_allocated(device: str):
|
||||||
|
return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
# If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
|
# If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
|
||||||
# into device to function mappings.
|
# into device to function mappings.
|
||||||
|
|||||||
@@ -62,6 +62,10 @@ from transformers.testing_utils import (
|
|||||||
TemporaryHubRepo,
|
TemporaryHubRepo,
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
backend_device_count,
|
backend_device_count,
|
||||||
|
backend_empty_cache,
|
||||||
|
backend_max_memory_allocated,
|
||||||
|
backend_memory_allocated,
|
||||||
|
backend_reset_max_memory_allocated,
|
||||||
evaluate_side_effect_factory,
|
evaluate_side_effect_factory,
|
||||||
execute_subprocess_async,
|
execute_subprocess_async,
|
||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
@@ -78,7 +82,6 @@ from transformers.testing_utils import (
|
|||||||
require_liger_kernel,
|
require_liger_kernel,
|
||||||
require_lomo,
|
require_lomo,
|
||||||
require_non_hpu,
|
require_non_hpu,
|
||||||
require_non_xpu,
|
|
||||||
require_optuna,
|
require_optuna,
|
||||||
require_peft,
|
require_peft,
|
||||||
require_ray,
|
require_ray,
|
||||||
@@ -245,18 +248,18 @@ def bytes2megabytes(x):
|
|||||||
class TorchTracemalloc:
|
class TorchTracemalloc:
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available():
|
if torch_device in ["cuda", "xpu"]:
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
|
backend_reset_max_memory_allocated(torch_device) # reset the peak gauge to zero
|
||||||
self.begin = torch.cuda.memory_allocated()
|
self.begin = backend_memory_allocated(torch_device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *exc):
|
def __exit__(self, *exc):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available():
|
if torch_device in ["cuda", "xpu"]:
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
self.end = torch.cuda.memory_allocated()
|
self.end = backend_memory_allocated(torch_device)
|
||||||
self.peak = torch.cuda.max_memory_allocated()
|
self.peak = backend_max_memory_allocated(torch_device)
|
||||||
self.used = bytes2megabytes(self.end - self.begin)
|
self.used = bytes2megabytes(self.end - self.begin)
|
||||||
self.peaked = bytes2megabytes(self.peak - self.begin)
|
self.peaked = bytes2megabytes(self.peak - self.begin)
|
||||||
|
|
||||||
@@ -1246,7 +1249,6 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
|
|
||||||
# will add more specific tests once there are some bugs to fix
|
# will add more specific tests once there are some bugs to fix
|
||||||
|
|
||||||
@require_non_xpu
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_torch_tf32
|
@require_torch_tf32
|
||||||
def test_tf32(self):
|
def test_tf32(self):
|
||||||
@@ -1838,7 +1840,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_lomo
|
@require_lomo
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_lomo(self):
|
def test_lomo(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -1861,7 +1863,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
|
self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
|
||||||
|
|
||||||
@require_lomo
|
@require_lomo
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_adalomo(self):
|
def test_adalomo(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2027,7 +2029,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertFalse(is_regex)
|
self.assertFalse(is_regex)
|
||||||
|
|
||||||
@require_galore_torch
|
@require_galore_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_galore(self):
|
def test_galore(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2048,7 +2050,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_galore_torch
|
@require_galore_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_galore_extra_args(self):
|
def test_galore_extra_args(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2070,7 +2072,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_galore_torch
|
@require_galore_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_galore_layerwise(self):
|
def test_galore_layerwise(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2091,7 +2093,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_galore_torch
|
@require_galore_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_galore_layerwise_with_scheduler(self):
|
def test_galore_layerwise_with_scheduler(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2113,7 +2115,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_galore_torch
|
@require_galore_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_galore_adamw_8bit(self):
|
def test_galore_adamw_8bit(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2134,7 +2136,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_galore_torch
|
@require_galore_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_galore_adafactor(self):
|
def test_galore_adafactor(self):
|
||||||
# These are the intervals of the peak memory usage of training such a tiny model
|
# These are the intervals of the peak memory usage of training such a tiny model
|
||||||
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
||||||
@@ -2166,7 +2168,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
||||||
|
|
||||||
@require_galore_torch
|
@require_galore_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_galore_adafactor_attention_only(self):
|
def test_galore_adafactor_attention_only(self):
|
||||||
# These are the intervals of the peak memory usage of training such a tiny model
|
# These are the intervals of the peak memory usage of training such a tiny model
|
||||||
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
||||||
@@ -2197,7 +2199,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
||||||
|
|
||||||
@require_galore_torch
|
@require_galore_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_galore_adafactor_all_linear(self):
|
def test_galore_adafactor_all_linear(self):
|
||||||
# These are the intervals of the peak memory usage of training such a tiny model
|
# These are the intervals of the peak memory usage of training such a tiny model
|
||||||
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
||||||
@@ -2305,7 +2307,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
||||||
|
|
||||||
@require_apollo_torch
|
@require_apollo_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_apollo(self):
|
def test_apollo(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2326,7 +2328,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_apollo_torch
|
@require_apollo_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_apollo_extra_args(self):
|
def test_apollo_extra_args(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2348,7 +2350,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_apollo_torch
|
@require_apollo_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_apollo_layerwise(self):
|
def test_apollo_layerwise(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2369,7 +2371,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_apollo_torch
|
@require_apollo_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_apollo_layerwise_with_scheduler(self):
|
def test_apollo_layerwise_with_scheduler(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2391,7 +2393,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
_ = trainer.train()
|
_ = trainer.train()
|
||||||
|
|
||||||
@require_apollo_torch
|
@require_apollo_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_apollo_lr_display_without_scheduler(self):
|
def test_apollo_lr_display_without_scheduler(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -2416,7 +2418,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
|
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
|
||||||
|
|
||||||
@require_apollo_torch
|
@require_apollo_torch
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_apollo_lr_display_with_scheduler(self):
|
def test_apollo_lr_display_with_scheduler(self):
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
@@ -3995,7 +3997,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
# perfect world: fp32_init/2 == fp16_eval
|
# perfect world: fp32_init/2 == fp16_eval
|
||||||
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||||
|
|
||||||
@require_non_xpu
|
@require_torch_gpu
|
||||||
@require_torch_non_multi_gpu
|
@require_torch_non_multi_gpu
|
||||||
@require_torch_tensorrt_fx
|
@require_torch_tensorrt_fx
|
||||||
def test_torchdynamo_full_eval(self):
|
def test_torchdynamo_full_eval(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user