Device agnostic testing (#25870)
* adds agnostic decorators and availability fns
* renaming decorators and fixing imports
* updating some representative example tests bloom, opt, and reformer for now
* wip device agnostic functions
* lru cache to device checking functions
* adds `TRANSFORMERS_TEST_DEVICE_SPEC` if present, imports the target file and updates device to function mappings
* comments `TRANSFORMERS_TEST_DEVICE_SPEC` code
* extra checks on device name
* `make style; make quality`
* updates default functions for agnostic calls
* applies suggestions from review
* adds `is_torch_available` guard
* Add spec file to docs, rename function dispatch names to backend_*
* add backend import to docs example for spec file
* change instances of to
* Move register backend to before device check as per @statelesshz changes
* make style
* make opt test require fp16 to run

---------

Co-authored-by: arsalanu <arsalanu@graphcore.ai>
Co-authored-by: arsalanu <hzji210@gmail.com>
@@ -525,6 +525,25 @@ Certain devices will require an additional import after importing `torch` for th
```bash
TRANSFORMERS_TEST_BACKEND="torch_npu" pytest tests/utils/test_logging.py
```

Alternative backends may also require the replacement of device-specific functions. For example, `torch.cuda.manual_seed` may need to be replaced with a device-specific seed setter such as `torch.npu.manual_seed` to correctly set a random seed on the device. To specify a new backend with backend-specific device functions when running the test suite, create a Python device specification file in the following format:

```
import torch
import torch_npu
# !! Further additional imports can be added here !!

# Specify the device name (e.g. 'cuda', 'cpu', 'npu')
DEVICE_NAME = 'npu'

# Specify device-specific backends to dispatch to.
# If not specified, will fall back to 'default' in 'testing_utils.py'
MANUAL_SEED_FN = torch.npu.manual_seed
EMPTY_CACHE_FN = torch.npu.empty_cache
DEVICE_COUNT_FN = torch.npu.device_count
```

This format also lets you specify any additional imports the backend requires. To have the test suite use the functions from this file in place of its defaults, set the environment variable `TRANSFORMERS_TEST_DEVICE_SPEC` to the path of the spec file.

Currently, only `MANUAL_SEED_FN`, `EMPTY_CACHE_FN` and `DEVICE_COUNT_FN` are supported for device-specific dispatch.
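
To make the dispatch idea concrete, here is a minimal sketch of how a spec file of this shape could be loaded and consulted. It is not the code in `testing_utils.py`: the helper names `spec_path_from_env`, `load_device_spec`, `build_dispatch` and `call_backend` are hypothetical and exist only for illustration. The sketch assumes the spec file defines `DEVICE_NAME` and that any of the three supported attributes it leaves out should fall back to a `'default'` entry.

```python
import importlib.util
import os
from pathlib import Path

# Hypothetical helpers for illustration only; not the library's implementation.
SPEC_ATTRIBUTES = ("MANUAL_SEED_FN", "EMPTY_CACHE_FN", "DEVICE_COUNT_FN")


def spec_path_from_env():
    """Return the spec file path from the environment, or None if unset."""
    return os.environ.get("TRANSFORMERS_TEST_DEVICE_SPEC")


def load_device_spec(path):
    """Import a device spec file from an explicit path and return the module."""
    spec_path = Path(path)
    if not spec_path.is_file():
        raise ValueError(f"Device spec file not found: {spec_path}")
    module_spec = importlib.util.spec_from_file_location("device_spec", spec_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    if not hasattr(module, "DEVICE_NAME"):
        raise AttributeError("Device spec file must define DEVICE_NAME")
    return module


def build_dispatch(module, defaults):
    """Build {attribute: {'default': fallback, device_name: spec_fn}} tables."""
    dispatch = {}
    for name in SPEC_ATTRIBUTES:
        table = {"default": defaults[name]}
        fn = getattr(module, name, None)
        if fn is not None:
            # Only devices named in the spec file get a dedicated entry.
            table[module.DEVICE_NAME] = fn
        dispatch[name] = table
    return dispatch


def call_backend(dispatch, name, device, *args):
    """Call the device-specific function, falling back to the 'default' entry."""
    table = dispatch[name]
    fn = table.get(device, table["default"])
    return fn(*args)
```

Keeping one table per function, keyed by device name, means a spec file only needs to list the functions that differ on its device; everything else keeps using the default.
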
### Distributed training