enable misc cases on XPU & use device agnostic APIs for cases in tests (#38192)
* use device agnostic APIs in tests Signed-off-by: Matrix Yao <matrix.yao@intel.com> * more Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> * add reset_peak_memory_stats API Signed-off-by: YAO Matrix <matrix.yao@intel.com> * update --------- Signed-off-by: Matrix Yao <matrix.yao@intel.com> Signed-off-by: YAO Matrix <matrix.yao@intel.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -3024,6 +3024,11 @@ if is_torch_available():
|
|||||||
"cpu": 0,
|
"cpu": 0,
|
||||||
"default": 0,
|
"default": 0,
|
||||||
}
|
}
|
||||||
|
BACKEND_RESET_PEAK_MEMORY_STATS = {
|
||||||
|
"cuda": torch.cuda.reset_peak_memory_stats,
|
||||||
|
"cpu": None,
|
||||||
|
"default": None,
|
||||||
|
}
|
||||||
BACKEND_MEMORY_ALLOCATED = {
|
BACKEND_MEMORY_ALLOCATED = {
|
||||||
"cuda": torch.cuda.memory_allocated,
|
"cuda": torch.cuda.memory_allocated,
|
||||||
"cpu": 0,
|
"cpu": 0,
|
||||||
@@ -3044,6 +3049,7 @@ else:
|
|||||||
BACKEND_EMPTY_CACHE = {"default": None}
|
BACKEND_EMPTY_CACHE = {"default": None}
|
||||||
BACKEND_DEVICE_COUNT = {"default": lambda: 0}
|
BACKEND_DEVICE_COUNT = {"default": lambda: 0}
|
||||||
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
|
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
|
||||||
|
BACKEND_RESET_PEAK_MEMORY_STATS = {"default": None}
|
||||||
BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
|
BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
|
||||||
BACKEND_MEMORY_ALLOCATED = {"default": 0}
|
BACKEND_MEMORY_ALLOCATED = {"default": 0}
|
||||||
BACKEND_SYNCHRONIZE = {"default": None}
|
BACKEND_SYNCHRONIZE = {"default": None}
|
||||||
@@ -3072,6 +3078,7 @@ if is_torch_xpu_available():
|
|||||||
BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
|
BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
|
||||||
BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
|
BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
|
||||||
BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
|
BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
|
||||||
|
BACKEND_RESET_PEAK_MEMORY_STATS["xpu"] = torch.xpu.reset_peak_memory_stats
|
||||||
BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
|
BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
|
||||||
BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
|
BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
|
||||||
BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
|
BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
|
||||||
@@ -3100,6 +3107,10 @@ def backend_reset_max_memory_allocated(device: str):
|
|||||||
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
|
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
|
||||||
|
|
||||||
|
|
||||||
|
def backend_reset_peak_memory_stats(device: str):
|
||||||
|
return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
|
||||||
|
|
||||||
|
|
||||||
def backend_max_memory_allocated(device: str):
|
def backend_max_memory_allocated(device: str):
|
||||||
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
|
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.models.idefics3 import Idefics3VisionConfig
|
from transformers.models.idefics3 import Idefics3VisionConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
backend_empty_cache,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_large_accelerator,
|
require_torch_large_accelerator,
|
||||||
@@ -302,7 +303,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_large_accelerator
|
@require_torch_large_accelerator
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ import unittest
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
from transformers import CohereTokenizerFast
|
from transformers import CohereTokenizerFast
|
||||||
from transformers.testing_utils import require_jinja, require_tokenizers, require_torch_multi_gpu
|
from transformers.testing_utils import (
|
||||||
|
require_jinja,
|
||||||
|
require_tokenizers,
|
||||||
|
require_torch_multi_accelerator,
|
||||||
|
)
|
||||||
|
|
||||||
from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible
|
from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible
|
||||||
|
|
||||||
@@ -55,7 +59,7 @@ class CohereTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
return CohereTokenizerFast.from_pretrained(pretrained_name, **kwargs)
|
return CohereTokenizerFast.from_pretrained(pretrained_name, **kwargs)
|
||||||
|
|
||||||
# This gives CPU OOM on a single-gpu runner (~60G RAM). On multi-gpu runner, it has ~180G RAM which is enough.
|
# This gives CPU OOM on a single-gpu runner (~60G RAM). On multi-gpu runner, it has ~180G RAM which is enough.
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_accelerator
|
||||||
def test_torch_encode_plus_sent_to_model(self):
|
def test_torch_encode_plus_sent_to_model(self):
|
||||||
super().test_torch_encode_plus_sent_to_model()
|
super().test_torch_encode_plus_sent_to_model()
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from transformers.models.colpali.configuration_colpali import ColPaliConfig
|
|||||||
from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
|
from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
|
||||||
from transformers.models.colpali.processing_colpali import ColPaliProcessor
|
from transformers.models.colpali.processing_colpali import ColPaliProcessor
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
backend_empty_cache,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_vision,
|
require_vision,
|
||||||
slow,
|
slow,
|
||||||
@@ -303,7 +304,7 @@ class ColPaliModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_integration_test(self):
|
def test_model_integration_test(self):
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from transformers.testing_utils import (
|
|||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_accelerator,
|
||||||
require_torch_sdpa,
|
require_torch_sdpa,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
@@ -583,7 +583,7 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
cleanup(torch_device, gc_collect=True)
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_accelerator
|
||||||
def test_integration_test(self):
|
def test_integration_test(self):
|
||||||
model = Idefics2ForConditionalGeneration.from_pretrained(
|
model = Idefics2ForConditionalGeneration.from_pretrained(
|
||||||
"HuggingFaceM4/idefics2-8b-base",
|
"HuggingFaceM4/idefics2-8b-base",
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import backend_empty_cache, require_soundfile, require_torch, slow, torch_device
|
||||||
from transformers.utils import is_soundfile_available
|
from transformers.utils import is_soundfile_available
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -296,7 +296,7 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
def test_text_only_generation(self):
|
def test_text_only_generation(self):
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from transformers import (
|
|||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
backend_empty_cache,
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_cv2,
|
require_cv2,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -421,7 +422,7 @@ class Qwen2_5_VLIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_small_model_integration_test(self):
|
def test_small_model_integration_test(self):
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from transformers import (
|
|||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
backend_empty_cache,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
@@ -367,7 +368,7 @@ class Qwen2VLIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_small_model_integration_test(self):
|
def test_small_model_integration_test(self):
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from transformers import WhisperConfig
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_non_xpu,
|
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_fp16,
|
require_torch_fp16,
|
||||||
@@ -42,7 +41,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import cached_property, is_torch_available, is_torchaudio_available
|
from transformers.utils import cached_property, is_torch_available, is_torch_xpu_available, is_torchaudio_available
|
||||||
from transformers.utils.import_utils import is_datasets_available
|
from transformers.utils.import_utils import is_datasets_available
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -2431,11 +2430,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
" How many different species are there in the chilli? How many different species are there in the chilli?",
|
" How many different species are there in the chilli? How many different species are there in the chilli?",
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_non_xpu
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_speculative_decoding_distil(self):
|
def test_speculative_decoding_distil(self):
|
||||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
torch_dtype = torch.float16 if (torch.cuda.is_available() or is_torch_xpu_available()) else torch.float32
|
||||||
model_id = "openai/whisper-large-v2"
|
model_id = "openai/whisper-large-v2"
|
||||||
model = WhisperForConditionalGeneration.from_pretrained(
|
model = WhisperForConditionalGeneration.from_pretrained(
|
||||||
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
|
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
|
||||||
|
|||||||
@@ -21,9 +21,11 @@ from transformers import is_torch_available
|
|||||||
from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
|
from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
|
backend_device_count,
|
||||||
get_torch_dist_unique_port,
|
get_torch_dist_unique_port,
|
||||||
require_huggingface_hub_greater_or_equal,
|
require_huggingface_hub_greater_or_equal,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -168,4 +170,4 @@ class TestTensorParallel(TestCasePlus):
|
|||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
class TestTensorParallelCuda(TestTensorParallel):
|
class TestTensorParallelCuda(TestTensorParallel):
|
||||||
nproc_per_node = torch.cuda.device_count()
|
nproc_per_node = backend_device_count(torch_device)
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ from transformers.models.auto.modeling_auto import (
|
|||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
|
backend_empty_cache,
|
||||||
get_device_properties,
|
get_device_properties,
|
||||||
hub_retry,
|
hub_retry,
|
||||||
is_flaky,
|
is_flaky,
|
||||||
@@ -2652,7 +2653,7 @@ class ModelTesterMixin:
|
|||||||
config = self.model_tester.get_large_model_config()
|
config = self.model_tester.get_large_model_config()
|
||||||
|
|
||||||
for model_class in self.all_parallelizable_model_classes:
|
for model_class in self.all_parallelizable_model_classes:
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
# 1. single gpu memory load + unload + memory measurements
|
# 1. single gpu memory load + unload + memory measurements
|
||||||
# Retrieve initial memory usage (can easily be ~0.6-1.5GB if cuda-kernels have been preloaded by previous tests)
|
# Retrieve initial memory usage (can easily be ~0.6-1.5GB if cuda-kernels have been preloaded by previous tests)
|
||||||
@@ -2668,7 +2669,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
del model
|
del model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
# 2. MP test
|
# 2. MP test
|
||||||
# it's essential to re-calibrate the usage before the next stage
|
# it's essential to re-calibrate the usage before the next stage
|
||||||
@@ -2692,7 +2693,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
del model
|
del model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ from transformers.testing_utils import (
|
|||||||
backend_max_memory_allocated,
|
backend_max_memory_allocated,
|
||||||
backend_memory_allocated,
|
backend_memory_allocated,
|
||||||
backend_reset_max_memory_allocated,
|
backend_reset_max_memory_allocated,
|
||||||
|
backend_reset_peak_memory_stats,
|
||||||
evaluate_side_effect_factory,
|
evaluate_side_effect_factory,
|
||||||
execute_subprocess_async,
|
execute_subprocess_async,
|
||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
@@ -1654,7 +1655,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))
|
self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))
|
||||||
|
|
||||||
def test_train_and_eval_dataloaders(self):
|
def test_train_and_eval_dataloaders(self):
|
||||||
if torch_device == "cuda":
|
if torch_device in ["cuda", "xpu"]:
|
||||||
n_gpu = max(1, backend_device_count(torch_device))
|
n_gpu = max(1, backend_device_count(torch_device))
|
||||||
else:
|
else:
|
||||||
n_gpu = 1
|
n_gpu = 1
|
||||||
@@ -4106,7 +4107,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
mod = MyModule()
|
mod = MyModule()
|
||||||
|
|
||||||
# 1. without TorchDynamo (eager baseline)
|
# 1. without TorchDynamo (eager baseline)
|
||||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
a = torch.ones(1024, 1024, device=torch_device, requires_grad=True)
|
||||||
a.grad = None
|
a.grad = None
|
||||||
trainer = CustomTrainer(model=mod)
|
trainer = CustomTrainer(model=mod)
|
||||||
# warmup
|
# warmup
|
||||||
@@ -4115,17 +4116,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
|
|
||||||
# resets
|
# resets
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
torch.cuda.reset_peak_memory_stats()
|
backend_reset_peak_memory_stats(torch_device)
|
||||||
|
|
||||||
orig_loss = trainer.training_step(mod, {"x": a})
|
orig_loss = trainer.training_step(mod, {"x": a})
|
||||||
orig_peak_mem = torch.cuda.max_memory_allocated()
|
orig_peak_mem = backend_max_memory_allocated(torch_device)
|
||||||
torchdynamo.reset()
|
torchdynamo.reset()
|
||||||
del trainer
|
del trainer
|
||||||
|
|
||||||
# 2. TorchDynamo nvfuser
|
# 2. TorchDynamo nvfuser
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
a = torch.ones(1024, 1024, device=torch_device, requires_grad=True)
|
||||||
a.grad = None
|
a.grad = None
|
||||||
args = TrainingArguments(output_dir=tmp_dir, torch_compile_backend="nvfuser")
|
args = TrainingArguments(output_dir=tmp_dir, torch_compile_backend="nvfuser")
|
||||||
trainer = CustomTrainer(model=mod, args=args)
|
trainer = CustomTrainer(model=mod, args=args)
|
||||||
@@ -4135,11 +4136,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
|
|
||||||
# resets
|
# resets
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
torch.cuda.reset_peak_memory_stats()
|
backend_reset_peak_memory_stats(torch_device)
|
||||||
|
|
||||||
loss = trainer.training_step(mod, {"x": a})
|
loss = trainer.training_step(mod, {"x": a})
|
||||||
peak_mem = torch.cuda.max_memory_allocated()
|
peak_mem = backend_max_memory_allocated(torch_device)
|
||||||
torchdynamo.reset()
|
torchdynamo.reset()
|
||||||
del trainer
|
del trainer
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from transformers import set_seed
|
|||||||
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
|
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureStderr,
|
CaptureStderr,
|
||||||
|
backend_device_count,
|
||||||
cleanup,
|
cleanup,
|
||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -210,8 +211,8 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
if not has_accelerator:
|
if not has_accelerator:
|
||||||
self.skipTest("Offloaded caches require an accelerator")
|
self.skipTest("Offloaded caches require an accelerator")
|
||||||
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
|
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
|
||||||
if torch.cuda.device_count() != 1:
|
if backend_device_count(torch_device) != 1:
|
||||||
self.skipTest("Offloaded static caches require exactly 1 GPU")
|
self.skipTest("Offloaded static caches require exactly 1 accelerator")
|
||||||
|
|
||||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||||
def test_cache_batched(self, cache_implementation):
|
def test_cache_batched(self, cache_implementation):
|
||||||
|
|||||||
Reference in New Issue
Block a user