Adds TRANSFORMERS_TEST_BACKEND (#25655)
* Adds `TRANSFORMERS_TEST_BACKEND` Allows specifying arbitrary additional import following first `import torch`. This is useful for some custom backends, that will require additional imports to trigger backend registration with upstream torch. See https://github.com/pytorch/benchmark/pull/1805 for a similar change in `torchbench`. * Update src/transformers/testing_utils.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * Adds real backend example to documentation --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import collections
|
||||
import contextlib
|
||||
import doctest
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import multiprocessing
|
||||
@@ -642,6 +643,17 @@ if is_torch_available():
|
||||
torch_device = "npu"
|
||||
else:
|
||||
torch_device = "cpu"
|
||||
|
||||
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
|
||||
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
|
||||
try:
|
||||
_ = importlib.import_module(backend)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
|
||||
f" traceback):\n{e}"
|
||||
) from e
|
||||
|
||||
else:
|
||||
torch_device = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user