Device agnostic testing (#25870)

* adds agnostic decorators and availability fns

* renaming decorators and fixing imports

* updating some representative example tests
bloom, opt, and reformer for now

* wip device agnostic functions

* lru cache to device checking functions

* adds `TRANSFORMERS_TEST_DEVICE_SPEC`
if present, imports the target file and updates device to function
mappings

* comments `TRANSFORMERS_TEST_DEVICE_SPEC` code

* extra checks on device name

* `make style; make quality`

* updates default functions for agnostic calls

* applies suggestions from review

* adds `is_torch_available` guard

* Add spec file to docs, rename function dispatch names to backend_*

* add backend import to docs example for spec file

* change instances of  to

* Move register backend to before device check as per @statelesshz changes

* make style

* make opt test require fp16 to run

---------

Co-authored-by: arsalanu <arsalanu@graphcore.ai>
Co-authored-by: arsalanu <hzji210@gmail.com>
This commit is contained in:
Alex McKinney
2023-10-24 15:49:26 +01:00
committed by GitHub
parent 41496b95da
commit 9da451713d
8 changed files with 188 additions and 25 deletions

View File

@@ -32,7 +32,7 @@ import unittest
from collections.abc import Mapping
from io import StringIO
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Union
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
from unittest import mock
import huggingface_hub
@@ -98,8 +98,10 @@ from .utils import (
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_fp16_available_on_device,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tensorrt_fx_available,
@@ -713,6 +715,16 @@ if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
try:
_ = importlib.import_module(backend)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
f" traceback):\n{e}"
) from e
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
try:
@@ -730,17 +742,6 @@ if is_torch_available():
torch_device = "xpu"
else:
torch_device = "cpu"
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
try:
_ = importlib.import_module(backend)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
f" traceback):\n{e}"
) from e
else:
torch_device = None
@@ -770,6 +771,25 @@ def require_torch_gpu(test_case):
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
def require_torch_accelerator(test_case):
    """Decorator marking a test that requires an accessible accelerator and PyTorch.

    The test is skipped when the module-level `torch_device` is `None` or `"cpu"`.
    """
    # `torch_device` can be `None` (the module-level device selection has a
    # `torch_device = None` fallback). A bare `torch_device != "cpu"` is True for `None`,
    # which would let the test run with no usable device, so check for it explicitly.
    has_accelerator = torch_device is not None and torch_device != "cpu"
    return unittest.skipUnless(has_accelerator, "test requires accelerator")(test_case)
def require_torch_fp16(test_case):
    """Skip the decorated test unless the current test device supports fp16.

    Support is queried through `is_torch_fp16_available_on_device` for the module-level
    `torch_device`.
    """
    fp16_supported = is_torch_fp16_available_on_device(torch_device)
    skip_without_fp16 = unittest.skipUnless(fp16_supported, "test requires device with fp16 support")
    return skip_without_fp16(test_case)
def require_torch_bf16(test_case):
    """Skip the decorated test unless the current test device supports bf16.

    Support is queried through `is_torch_bf16_available_on_device` for the module-level
    `torch_device`.
    """
    bf16_supported = is_torch_bf16_available_on_device(torch_device)
    skip_without_bf16 = unittest.skipUnless(bf16_supported, "test requires device with bf16 support")
    return skip_without_bf16(test_case)
def require_torch_bf16_gpu(test_case):
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
return unittest.skipUnless(
@@ -2176,3 +2196,86 @@ class HfDoctestModule(Module):
for test in finder.find(module, module.__name__):
if test.examples: # skip empty doctests and cuda
yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
if device not in dispatch_table:
return dispatch_table["default"](*args, **kwargs)
fn = dispatch_table[device]
# Some device agnostic functions return values. Need to guard against `None`
# instead at user level.
if fn is None:
return None
return fn(*args, **kwargs)
if is_torch_available():
    # Dispatch tables for device agnostic testing: each maps a device name to the
    # function to call for that device. The "default" key holds the entry used for
    # devices that are not listed; a `None` value means no function is registered.

    # Seeding the random number generator.
    BACKEND_MANUAL_SEED = {
        "cuda": torch.cuda.manual_seed,
        "cpu": torch.manual_seed,
        "default": torch.manual_seed,
    }

    # Emptying the device cache.
    BACKEND_EMPTY_CACHE = {
        "cuda": torch.cuda.empty_cache,
        "cpu": None,
        "default": None,
    }

    # Counting available devices.
    BACKEND_DEVICE_COUNT = {
        "cuda": torch.cuda.device_count,
        "cpu": lambda: 0,
        "default": lambda: 1,
    }
def backend_manual_seed(device: str, seed: int):
    """Dispatch a manual-seed call for `device` through `BACKEND_MANUAL_SEED`, passing `seed`."""
    seed_fns = BACKEND_MANUAL_SEED
    return _device_agnostic_dispatch(device, seed_fns, seed)
def backend_empty_cache(device: str):
    """Dispatch an empty-cache call for `device` through `BACKEND_EMPTY_CACHE`."""
    empty_cache_fns = BACKEND_EMPTY_CACHE
    return _device_agnostic_dispatch(device, empty_cache_fns)
def backend_device_count(device: str):
    """Dispatch a device-count call for `device` through `BACKEND_DEVICE_COUNT`."""
    device_count_fns = BACKEND_DEVICE_COUNT
    return _device_agnostic_dispatch(device, device_count_fns)
if is_torch_available():
    # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
    # into device to function mappings.
    if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
        device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
        if not Path(device_spec_path).is_file():
            raise ValueError(
                f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}'"
            )

        # Strip the extension for the later import; this also verifies that we are
        # importing a Python file. Only a trailing `.py` counts: searching for the first
        # ".py" anywhere in the path would accept names such as `spec.pyc` and would cut
        # the path at the wrong place when ".py" appears earlier in it.
        if not device_spec_path.endswith(".py"):
            raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}'")
        import_name = device_spec_path[: -len(".py")]

        device_spec_module = importlib.import_module(import_name)

        # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
        try:
            device_name = device_spec_module.DEVICE_NAME
        except AttributeError as e:
            raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e

        # An explicitly requested device must agree with the device named by the spec.
        if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
            msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
            msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
            raise ValueError(msg)

        torch_device = device_name

        def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
            """Register the spec module's `attribute_name` under `torch_device` in `device_fn_dict`.

            If the spec module does not define `attribute_name`, the mapping is left
            untouched as long as it has a `"default"` entry to fall back on.

            Raises:
                AttributeError: If the attribute is missing and there is no `"default"` entry.
            """
            try:
                # Try to import the function directly
                spec_fn = getattr(device_spec_module, attribute_name)
            except AttributeError as e:
                # If the function doesn't exist, and there is no default, throw an error
                if "default" not in device_fn_dict:
                    raise AttributeError(
                        f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
                    ) from e
            else:
                device_fn_dict[torch_device] = spec_fn

        # Add one entry here for each `BACKEND_*` dictionary.
        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")