Adds TRANSFORMERS_TEST_BACKEND (#25655)

* Adds `TRANSFORMERS_TEST_BACKEND`
Allows specifying arbitrary additional import following first `import torch`.
This is useful for some custom backends, that will require additional imports to trigger backend registration with upstream torch.
See https://github.com/pytorch/benchmark/pull/1805 for a similar change in `torchbench`.

* Update src/transformers/testing_utils.py

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* Adds real backend example to documentation

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Alex McKinney
2023-08-22 16:08:13 +01:00
committed by GitHub
parent fd56f7f081
commit 5eeaef921f
2 changed files with 19 additions and 2 deletions

View File

@@ -16,6 +16,7 @@ import collections
import contextlib
import doctest
import functools
import importlib
import inspect
import logging
import multiprocessing
@@ -642,6 +643,17 @@ if is_torch_available():
torch_device = "npu"
else:
torch_device = "cpu"
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
try:
_ = importlib.import_module(backend)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
f" traceback):\n{e}"
) from e
else:
torch_device = None