[Benchmark] add tpu and torchscipt for benchmark (#4850)

* add tpu and torchscipt for benchmark

* fix name in tests

* "fix email"

* make style

* better log message for tpu

* add more print and info for tpu

* allow possibility to print tpu metrics

* correct cpu usage

* fix test for non-install

* remove bugus file

* include psutil in testing

* run a couple of times before tracing in torchscript

* do not allow tpu memory tracing for now

* make style

* add torchscript to env

* better name for torch tpu

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2020-06-09 23:12:43 +02:00
committed by GitHub
parent f0340b3031
commit 2cfb947f59
9 changed files with 317 additions and 84 deletions

View File

@@ -68,6 +68,21 @@ except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
try:
import torch_xla.core.xla_model as xm
tpu_device = xm.xla_device()
if _torch_available:
_torch_tpu_available = True # pylint: disable=
else:
_torch_tpu_available = False
except ImportError:
_torch_tpu_available = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
@@ -98,6 +113,10 @@ def is_tf_available():
return _tf_available
def is_torch_tpu_available():
return _torch_tpu_available
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")