[Benchmark] add tpu and torchscipt for benchmark (#4850)
* add tpu and torchscipt for benchmark * fix name in tests * "fix email" * make style * better log message for tpu * add more print and info for tpu * allow possibility to print tpu metrics * correct cpu usage * fix test for non-install * remove bugus file * include psutil in testing * run a couple of times before tracing in torchscript * do not allow tpu memory tracing for now * make style * add torchscript to env * better name for torch tpu Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
f0340b3031
commit
2cfb947f59
@@ -68,6 +68,21 @@ except ImportError:
|
||||
torch_cache_home = os.path.expanduser(
|
||||
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
tpu_device = xm.xla_device()
|
||||
|
||||
if _torch_available:
|
||||
_torch_tpu_available = True # pylint: disable=
|
||||
else:
|
||||
_torch_tpu_available = False
|
||||
except ImportError:
|
||||
_torch_tpu_available = False
|
||||
|
||||
|
||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
|
||||
|
||||
@@ -98,6 +113,10 @@ def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
||||
def is_torch_tpu_available():
|
||||
return _torch_tpu_available
|
||||
|
||||
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
||||
|
||||
Reference in New Issue
Block a user