Device agnostic testing (#25870)
* adds agnostic decorators and availability fns * renaming decorators and fixing imports * updating some representative example tests bloom, opt, and reformer for now * wip device agnostic functions * lru cache to device checking functions * adds `TRANSFORMERS_TEST_DEVICE_SPEC` if present, imports the target file and updates device to function mappings * comments `TRANSFORMERS_TEST_DEVICE_SPEC` code * extra checks on device name * `make style; make quality` * updates default functions for agnostic calls * applies suggestions from review * adds `is_torch_available` guard * Add spec file to docs, rename function dispatch names to backend_* * add backend import to docs example for spec file * change instances of to * Move register backend to before device check as per @statelesshz changes * make style * make opt test require fp16 to run --------- Co-authored-by: arsalanu <arsalanu@graphcore.ai> Co-authored-by: arsalanu <hzji210@gmail.com>
This commit is contained in:
@@ -32,7 +32,7 @@ import unittest
|
||||
from collections.abc import Mapping
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Iterator, List, Optional, Union
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
||||
from unittest import mock
|
||||
|
||||
import huggingface_hub
|
||||
@@ -98,8 +98,10 @@ from .utils import (
|
||||
is_timm_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_bf16_available_on_device,
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_neuroncore_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
@@ -713,6 +715,16 @@ if is_torch_available():
|
||||
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
|
||||
import torch
|
||||
|
||||
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
|
||||
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
|
||||
try:
|
||||
_ = importlib.import_module(backend)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
|
||||
f" traceback):\n{e}"
|
||||
) from e
|
||||
|
||||
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
|
||||
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
|
||||
try:
|
||||
@@ -730,17 +742,6 @@ if is_torch_available():
|
||||
torch_device = "xpu"
|
||||
else:
|
||||
torch_device = "cpu"
|
||||
|
||||
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
|
||||
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
|
||||
try:
|
||||
_ = importlib.import_module(backend)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
|
||||
f" traceback):\n{e}"
|
||||
) from e
|
||||
|
||||
else:
|
||||
torch_device = None
|
||||
|
||||
@@ -770,6 +771,25 @@ def require_torch_gpu(test_case):
|
||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator(test_case):
    """Decorator marking a test that requires an accessible accelerator and PyTorch.

    The test is skipped when `torch_device` is `"cpu"` or `None` (the value this module falls back to
    when no torch device was selected).
    """
    # `torch_device != "cpu"` on its own is also true for `None`, which would let the test run
    # even though no device (and possibly no PyTorch) is available.
    has_accelerator = torch_device is not None and torch_device != "cpu"
    return unittest.skipUnless(has_accelerator, "test requires accelerator")(test_case)
|
||||
|
||||
|
||||
def require_torch_fp16(test_case):
    """Skip the decorated test unless `torch_device` supports fp16.

    Support is determined by `is_torch_fp16_available_on_device(torch_device)`, evaluated once when the
    decorator is applied.
    """
    fp16_supported = is_torch_fp16_available_on_device(torch_device)
    skip_unless_fp16 = unittest.skipUnless(fp16_supported, "test requires device with fp16 support")
    return skip_unless_fp16(test_case)
|
||||
|
||||
|
||||
def require_torch_bf16(test_case):
    """Skip the decorated test unless `torch_device` supports bf16.

    Support is determined by `is_torch_bf16_available_on_device(torch_device)`, evaluated once when the
    decorator is applied.
    """
    bf16_supported = is_torch_bf16_available_on_device(torch_device)
    skip_unless_bf16 = unittest.skipUnless(bf16_supported, "test requires device with bf16 support")
    return skip_unless_bf16(test_case)
|
||||
|
||||
|
||||
def require_torch_bf16_gpu(test_case):
|
||||
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
|
||||
return unittest.skipUnless(
|
||||
@@ -2176,3 +2196,86 @@ class HfDoctestModule(Module):
|
||||
for test in finder.find(module, module.__name__):
|
||||
if test.examples: # skip empty doctests and cuda
|
||||
yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
|
||||
|
||||
|
||||
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
|
||||
if device not in dispatch_table:
|
||||
return dispatch_table["default"](*args, **kwargs)
|
||||
|
||||
fn = dispatch_table[device]
|
||||
|
||||
# Some device agnostic functions return values. Need to guard against `None`
|
||||
# instead at user level.
|
||||
if fn is None:
|
||||
return None
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
|
||||
if is_torch_available():
    # Device-name -> callable tables that back the `backend_*` helpers used for device
    # agnostic testing. A `None` entry means there is nothing to call for that device;
    # the "default" entry serves every device without an entry of its own.
    BACKEND_MANUAL_SEED = {
        "cuda": torch.cuda.manual_seed,
        "cpu": torch.manual_seed,
        "default": torch.manual_seed,
    }
    BACKEND_EMPTY_CACHE = {
        "cuda": torch.cuda.empty_cache,
        "cpu": None,
        "default": None,
    }
    BACKEND_DEVICE_COUNT = {
        "cuda": torch.cuda.device_count,
        "cpu": lambda: 0,
        "default": lambda: 1,
    }
|
||||
|
||||
|
||||
def backend_manual_seed(device: str, seed: int):
    """Invoke the `BACKEND_MANUAL_SEED` entry for `device` with `seed`.

    Devices without their own entry use the table's "default" entry.
    """
    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
|
||||
|
||||
|
||||
def backend_empty_cache(device: str):
    """Invoke the `BACKEND_EMPTY_CACHE` entry for `device`.

    Devices without their own entry use the table's "default" entry.
    """
    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
|
||||
|
||||
|
||||
def backend_device_count(device: str):
    """Return the result of the `BACKEND_DEVICE_COUNT` entry for `device`.

    Devices without their own entry use the table's "default" entry.
    """
    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
|
||||
|
||||
|
||||
if is_torch_available():
    # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
    # into device to function mappings.
    if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
        device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
        if not Path(device_spec_path).is_file():
            raise ValueError(
                f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}'"
            )

        # Strip the extension for the later import – this also verifies we are importing a
        # python file. Only a trailing `.py` counts: searching for the first `.py` anywhere in
        # the path would accept files such as `spec.pyc` and cut the name at the wrong place
        # whenever `.py` occurs earlier in the path.
        if not device_spec_path.endswith(".py"):
            raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}'")
        import_name = device_spec_path[: -len(".py")]

        device_spec_module = importlib.import_module(import_name)

        # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
        try:
            device_name = device_spec_module.DEVICE_NAME
        except AttributeError as e:
            raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e

        # An explicitly requested device must agree with the one named by the spec file.
        if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
            msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
            msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
            raise ValueError(msg)

        # From here on the spec file decides which device the tests run on.
        torch_device = device_name

        def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
            """Register the spec module's `attribute_name` in `device_fn_dict` under `torch_device`.

            A missing attribute is tolerated when `device_fn_dict` has a "default" entry to fall
            back on; the table is then left unchanged.

            Raises:
                AttributeError: If the spec module lacks `attribute_name` and `device_fn_dict` has
                    no "default" entry.
            """
            try:
                # Try to import the function directly
                spec_fn = getattr(device_spec_module, attribute_name)
            except AttributeError as e:
                # If the function doesn't exist, and there is no default, throw an error
                if "default" not in device_fn_dict:
                    raise AttributeError(
                        f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
                    ) from e
            else:
                device_fn_dict[torch_device] = spec_fn

        # Add one entry here for each `BACKEND_*` dictionary.
        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
|
||||
|
||||
Reference in New Issue
Block a user