Device agnostic testing (#25870)

* adds agnostic decorators and availability fns

* renaming decorators and fixing imports

* updating some representative example tests
bloom, opt, and reformer for now

* wip device agnostic functions

* lru cache to device checking functions

* adds `TRANSFORMERS_TEST_DEVICE_SPEC`
if present, imports the target file and updates device to function
mappings

* comments `TRANSFORMERS_TEST_DEVICE_SPEC` code

* extra checks on device name

* `make style; make quality`

* updates default functions for agnostic calls

* applies suggestions from review

* adds `is_torch_available` guard

* Add spec file to docs, rename function dispatch names to backend_*

* add backend import to docs example for spec file

* change instances of  to

* Move register backend to before device check as per @statelesshz changes

* make style

* make opt test require fp16 to run

---------

Co-authored-by: arsalanu <arsalanu@graphcore.ai>
Co-authored-by: arsalanu <hzji210@gmail.com>
This commit is contained in:
Alex McKinney
2023-10-24 15:49:26 +01:00
committed by GitHub
parent 41496b95da
commit 9da451713d
8 changed files with 188 additions and 25 deletions

View File

@@ -20,6 +20,7 @@ from transformers.testing_utils import (
require_sentencepiece,
require_tokenizers,
require_torch,
require_torch_fp16,
require_torch_multi_gpu,
slow,
torch_device,
@@ -563,12 +564,12 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
@require_torch_fp16
def test_reformer_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
@require_torch_fp16
def test_reformer_model_fp16_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)