[Benchmark] add tpu and torchscipt for benchmark (#4850)
* add tpu and torchscipt for benchmark * fix name in tests * "fix email" * make style * better log message for tpu * add more print and info for tpu * allow possibility to print tpu metrics * correct cpu usage * fix test for non-install * remove bugus file * include psutil in testing * run a couple of times before tracing in torchscript * do not allow tpu memory tracing for now * make style * add torchscript to env * better name for torch tpu Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
f0340b3031
commit
2cfb947f59
@@ -5,25 +5,15 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from .file_utils import cached_property, is_torch_available, torch_required
|
||||
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
try:
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
_has_tpu = True
|
||||
except ImportError:
|
||||
_has_tpu = False
|
||||
|
||||
|
||||
@torch_required
|
||||
def is_tpu_available():
|
||||
return _has_tpu
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -176,7 +166,7 @@ class TrainingArguments:
|
||||
if self.no_cuda:
|
||||
device = torch.device("cpu")
|
||||
n_gpu = 0
|
||||
elif is_tpu_available():
|
||||
elif is_torch_tpu_available():
|
||||
device = xm.xla_device()
|
||||
n_gpu = 0
|
||||
elif self.local_rank == -1:
|
||||
|
||||
Reference in New Issue
Block a user