From 2cfb947f59861d5d910f84eba3be57da200b5599 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 9 Jun 2020 23:12:43 +0200 Subject: [PATCH] [Benchmark] add tpu and torchscript for benchmark (#4850) * add tpu and torchscript for benchmark * fix name in tests * "fix email" * make style * better log message for tpu * add more print and info for tpu * allow possibility to print tpu metrics * correct cpu usage * fix test for non-install * remove bogus file * include psutil in testing * run a couple of times before tracing in torchscript * do not allow tpu memory tracing for now * make style * add torchscript to env * better name for torch tpu Co-authored-by: Patrick von Platen --- setup.py | 2 +- src/transformers/__init__.py | 1 + src/transformers/benchmark/benchmark.py | 139 ++++++++++++----- src/transformers/benchmark/benchmark_args.py | 17 +- src/transformers/benchmark/benchmark_utils.py | 146 +++++++++++++++++- src/transformers/file_utils.py | 19 +++ src/transformers/trainer.py | 30 ++-- src/transformers/training_args.py | 16 +- tests/test_benchmark.py | 31 ++++ 9 files changed, 317 insertions(+), 84 deletions(-) diff --git a/setup.py b/setup.py index b4e5bbdb0c..dbf24ce4e5 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,7 @@ extras["torch"] = ["torch"] extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"] extras["all"] = extras["serving"] + ["tensorflow", "torch"] -extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator"] +extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil"] extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme"] extras["quality"] = [ "black", diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 34dd885d86..ca973acec5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -78,6 +78,7 @@ from .file_utils import ( cached_path, is_tf_available, is_torch_available, + is_torch_tpu_available, ) from .hf_argparser 
import HfArgumentParser diff --git a/src/transformers/benchmark/benchmark.py b/src/transformers/benchmark/benchmark.py index 742d84b773..d9a8a4c8c3 100644 --- a/src/transformers/benchmark/benchmark.py +++ b/src/transformers/benchmark/benchmark.py @@ -19,12 +19,17 @@ import logging -import os import timeit -from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available +from transformers import ( + MODEL_MAPPING, + MODEL_WITH_LM_HEAD_MAPPING, + PretrainedConfig, + is_torch_available, + is_torch_tpu_available, +) -from .benchmark_utils import Benchmark, Memory, start_memory_tracing, stop_memory_tracing +from .benchmark_utils import Benchmark, Memory, measure_peak_memory_cpu, start_memory_tracing, stop_memory_tracing if is_torch_available(): @@ -48,6 +53,10 @@ class PyTorchBenchmark(Benchmark): def train(self, model_name, batch_size, sequence_length, trace_memory=False): try: config = self.config_dict[model_name] + + if self.args.torchscript: + config.torchscript = True + model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config) model.to(self.args.device) model.train() @@ -58,15 +67,20 @@ class PyTorchBenchmark(Benchmark): vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device ) + if self.args.torchscript: + raise NotImplementedError("Training for torchscript is currently not implemented") + else: + train_model = model + def compute_loss_and_backprob_encoder(): - loss = model(input_ids, labels=input_ids)[0] + loss = train_model(input_ids, labels=input_ids)[0] loss.backward() - model.zero_grad() + train_model.zero_grad() def compute_loss_and_backprob_encoder_decoder(): - loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0] + loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0] loss.backward() - model.zero_grad() + train_model.zero_grad() _train = ( compute_loss_and_backprob_encoder_decoder @@ -79,6 +93,7 @@ class PyTorchBenchmark(Benchmark): trace 
= start_memory_tracing("transformers") if self.args.n_gpu > 0: + # gpu # clear gpu cache torch.cuda.empty_cache() if hasattr(torch.cuda, "max_memory_reserved"): @@ -89,8 +104,17 @@ class PyTorchBenchmark(Benchmark): ) torch.cuda.reset_max_memory_cached() - # calculate loss and do backpropagation - _train() + # calculate loss and do backpropagation + _train() + elif not self.args.no_tpu and is_torch_tpu_available(): + # tpu + raise NotImplementedError( + "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`" + ) + else: + # cpu + memory_bytes = measure_peak_memory_cpu(_train) + memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes if self.args.trace_memory_line_by_line: summary = stop_memory_tracing(trace) @@ -107,40 +131,47 @@ class PyTorchBenchmark(Benchmark): ) memory = Memory(torch.cuda.max_memory_cached()) memory = Memory(torch.cuda.max_memory_reserved()) - else: - # cpu - try: - import psutil - except (ImportError): - logger.warning( - "Psutil not installed, we won't log CPU memory usage. " - "Install psutil (pip install psutil) to use CPU memory tracing." - ) - memory = "N/A" - else: - process = psutil.Process(os.getpid()) - memory = Memory(process.memory_info().rss) return memory, summary else: + if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript: + # run additional 5 times to stabilize compilation for tpu and torchscript + logger.info("Do inference on TPU or torchscript. 
Running model 5 times to stabilize compilation") + timeit.repeat( + _train, repeat=1, number=5, + ) + # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average runtimes = timeit.repeat(_train, repeat=self.args.repeat, number=10,) + + if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics: + import torch_xla.debug.metrics as met + + self.print_fn(met.metrics_report()) + return min(runtimes) / 10.0 except RuntimeError as e: self.print_fn("Doesn't fit on GPU. {}".format(e)) - return "N/A" + if trace_memory: + return "N/A", None + else: + return "N/A" def inference(self, model_name, batch_size, sequence_length, trace_memory=False): try: config = self.config_dict[model_name] + model = None + + if self.args.torchscript: + config.torchscript = True if self.args.with_lm_head: model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config) else: model = MODEL_MAPPING[config.__class__](config) - model.to(self.args.device) model.eval() + model.to(self.args.device) # encoder-decoder has vocab size saved differently vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size @@ -149,11 +180,22 @@ class PyTorchBenchmark(Benchmark): vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device ) + if self.args.torchscript: + with torch.no_grad(): + if config.is_encoder_decoder: + raise NotImplementedError("Torchscript is currently not supported for EncoderDecoder models") + else: + inference_model = torch.jit.trace(model, input_ids) + else: + inference_model = model + def encoder_decoder_forward(): - model(input_ids, decoder_input_ids=input_ids) + with torch.no_grad(): + inference_model(input_ids, decoder_input_ids=input_ids) def encoder_forward(): - model(input_ids) + with torch.no_grad(): + inference_model(input_ids) _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward @@ -162,6 +204,7 @@ class 
PyTorchBenchmark(Benchmark): trace = start_memory_tracing("transformers") if self.args.n_gpu > 0: + # gpu # clear gpu cache torch.cuda.empty_cache() if hasattr(torch.cuda, "max_memory_reserved"): @@ -172,7 +215,17 @@ class PyTorchBenchmark(Benchmark): ) torch.cuda.reset_max_memory_cached() - _forward() + # run forward + _forward() + elif not self.args.no_tpu and is_torch_tpu_available(): + # tpu + raise NotImplementedError( + "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`" + ) + else: + # cpu + memory_bytes = measure_peak_memory_cpu(_forward) + memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes if self.args.trace_memory_line_by_line: summary = stop_memory_tracing(trace) @@ -188,26 +241,30 @@ class PyTorchBenchmark(Benchmark): "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage" ) memory = Memory(torch.cuda.max_memory_cached()) - else: - # cpu - try: - import psutil - except (ImportError): - logger.warning( - "Psutil not installed, we won't log CPU memory usage. " - "Install psutil (pip install psutil) to use CPU memory tracing." - ) - memory = "N/A" - else: - process = psutil.Process(os.getpid()) - memory = Memory(process.memory_info().rss) return memory, summary else: + + if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript: + # run additional 5 times to stabilize compilation for tpu and torchscript + logger.info("Do inference on TPU or torchscript. 
Running model 5 times to stabilize compilation") + timeit.repeat( + _forward, repeat=1, number=5, + ) + # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average runtimes = timeit.repeat(_forward, repeat=self.args.repeat, number=10,) + + if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics: + import torch_xla.debug.metrics as met + + self.print_fn(met.metrics_report()) + return min(runtimes) / 10.0 except RuntimeError as e: self.print_fn("Doesn't fit on GPU. {}".format(e)) - return "N/A" + if trace_memory: + return "N/A", None + else: + return "N/A" diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py index 46e62fe368..0cc043537b 100644 --- a/src/transformers/benchmark/benchmark_args.py +++ b/src/transformers/benchmark/benchmark_args.py @@ -18,25 +18,16 @@ import logging from dataclasses import dataclass, field from typing import Tuple -from ..file_utils import cached_property, is_torch_available, torch_required +from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required from .benchmark_args_utils import BenchmarkArguments if is_torch_available(): import torch -try: +if is_torch_tpu_available(): import torch_xla.core.xla_model as xm - _has_tpu = True -except ImportError: - _has_tpu = False - - -@torch_required -def is_tpu_available(): - return _has_tpu - logger = logging.getLogger(__name__) @@ -45,7 +36,9 @@ logger = logging.getLogger(__name__) class PyTorchBenchmarkArguments(BenchmarkArguments): no_cuda: bool = field(default=False, metadata={"help": "Whether to run on available cuda devices"}) torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"}) + no_tpu: bool = field(default=False, metadata={"help": "Whether to run on available tpu devices"}) fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate 
inference."}) + tpu_print_metrics: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."}) @cached_property @torch_required @@ -54,7 +47,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments): if self.no_cuda: device = torch.device("cpu") n_gpu = 0 - elif is_tpu_available(): + elif is_torch_tpu_available(): device = xm.xla_device() n_gpu = 0 else: diff --git a/src/transformers/benchmark/benchmark_utils.py b/src/transformers/benchmark/benchmark_utils.py index 7b3c3304b1..4b40dafa9d 100644 --- a/src/transformers/benchmark/benchmark_utils.py +++ b/src/transformers/benchmark/benchmark_utils.py @@ -14,12 +14,15 @@ import sys from abc import ABC, abstractmethod from collections import defaultdict, namedtuple from datetime import datetime -from typing import Iterable, List, NamedTuple, Optional, Union +from multiprocessing import Pipe, Process +from multiprocessing.connection import Connection +from signal import SIGKILL +from typing import Callable, Iterable, List, NamedTuple, Optional, Union from transformers import AutoConfig, PretrainedConfig from transformers import __version__ as version -from ..file_utils import is_tf_available, is_torch_available +from ..file_utils import is_tf_available, is_torch_available, is_torch_tpu_available from .benchmark_args_utils import BenchmarkArguments @@ -128,6 +131,127 @@ class MemorySummary(NamedTuple): MemoryTrace = List[UsedMemoryState] +def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5) -> int: + """ + measures peak cpu memory consumption of a given `function` + running the function for at least interval seconds + and at most 20 * interval seconds. + This function is heavily inspired by: `memory_usage` + of the package `memory_profiler`: https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239 + + Args: + - `function`: (`callable`): function() -> ... 
+ function without any arguments for which to measure the peak memory + + - `interval`: (`float`) + interval in seconds for which to measure the memory usage + + Returns: + - `max_memory`: (`int`) + consumed memory peak in Bytes + """ + try: + import psutil + except (ImportError): + logger.warning( + "Psutil not installed, we won't log CPU memory usage. " + "Install Psutil (pip install psutil) to use CPU memory tracing." + ) + max_memory = "N/A" + else: + + def _get_memory(process_id: int) -> int: + """ + measures current cpu memory usage of a given `process_id` + + Args: + - `process_id`: (`int`) + process_id for which to measure memory + + Returns: + - `memory`: (`int`) + consumed memory in Bytes + """ + process = psutil.Process(process_id) + try: + meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info" + memory = getattr(process, meminfo_attr)()[0] + except psutil.AccessDenied: + raise ValueError("Error with Psutil.") + return memory + + class MemoryMeasureProcess(Process): + + """ + `MemoryMeasureProcess` inherits from `Process` and overwrites + its `run()` method. 
Used to measure the memory usage of a process + """ + + def __init__(self, process_id: int, child_connection: Connection, interval: float): + super().__init__() + self.process_id = process_id + self.interval = interval + self.connection = child_connection + self.num_measurements = 1 + self.mem_usage = _get_memory(process_id) + + def run(self): + self.connection.send(0) + stop = False + while True: + self.mem_usage = max(self.mem_usage, _get_memory(self.process_id)) + self.num_measurements += 1 + + if stop: + break + + stop = self.connection.poll(self.interval) + + # send results to parent pipe + self.connection.send(self.mem_usage) + self.connection.send(self.num_measurements) + + while True: + # create child, parent connection + child_connection, parent_connection = Pipe() + + # instantiate process + mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval) + mem_process.start() + + # wait until we get memory + parent_connection.recv() + + try: + # execute function + function() + + # start parent connection + parent_connection.send(0) + + # receive memory and num measurements + max_memory = parent_connection.recv() + num_measurements = parent_connection.recv() + except Exception: + # kill process in a clean way + parent = psutil.Process(os.getpid()) + for child in parent.children(recursive=True): + os.kill(child.pid, SIGKILL) + mem_process.join(0) + raise RuntimeError("Process killed. 
Error in Process") + + # run process at least 20 * interval or until it finishes + mem_process.join(20 * interval) + + if (num_measurements > 4) or (interval < 1e-6): + break + + # reduce interval + interval /= 10 + + return max_memory + + def start_memory_tracing( modules_to_trace: Optional[Union[str, Iterable[str]]] = None, modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None, @@ -424,6 +548,10 @@ class Benchmark(ABC): def is_gpu(self): return self.args.n_gpu > 0 + @property + def is_tpu(self): + return is_torch_tpu_available() and not self.args.no_tpu + @property @abstractmethod def framework_version(self): @@ -486,6 +614,10 @@ class Benchmark(ABC): self.print_fn("======= INFERENCE - SPEED - RESULT =======") self.print_results(inference_result_time) self.save_to_csv(inference_result_time, self.args.inference_time_csv_file) + if self.is_tpu: + self.print_fn( + "TPU was used for inference. Note that the time after compilation stabilized (after ~10 inferences model.forward(..) calls) was measured." + ) if not self.args.no_memory: self.print_fn("======= INFERENCE - MEMORY - RESULT =======") @@ -501,6 +633,10 @@ class Benchmark(ABC): self.print_fn("======= TRAIN - SPEED - RESULT =======") self.print_results(train_result_time) self.save_to_csv(train_result_time, self.args.train_time_csv_file) + if self.is_tpu: + self.print_fn( + "TPU was used for training. Note that the time after compilation stabilized (after ~10 train loss=model.forward(...) + loss.backward() calls) was measured." 
+ ) if not self.args.no_memory: self.print_fn("======= TRAIN - MEMORY - RESULT =======") @@ -538,6 +674,8 @@ class Benchmark(ABC): info = {} info["transformers_version"] = version info["framework"] = self.framework + if self.framework == "PyTorch": + info["use_torchscript"] = self.args.torchscript info["framework_version"] = self.framework_version info["python_version"] = platform.python_version() info["system"] = platform.system() @@ -590,6 +728,10 @@ class Benchmark(ABC): info["gpu_performance_state"] = py3nvml.nvmlDeviceGetPerformanceState(handle) py3nvml.nvmlShutdown() + info["use_tpu"] = self.is_tpu + # TODO(PVP): See if we can add more information about TPU + # see: https://github.com/pytorch/xla/issues/2180 + self._environment_info = info return self._environment_info diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index fa9e17e833..a6925aa082 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -68,6 +68,21 @@ except ImportError: torch_cache_home = os.path.expanduser( os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) ) + + +try: + import torch_xla.core.xla_model as xm + + tpu_device = xm.xla_device() + + if _torch_available: + _torch_tpu_available = True # pylint: disable= + else: + _torch_tpu_available = False +except ImportError: + _torch_tpu_available = False + + default_cache_path = os.path.join(torch_cache_home, "transformers") @@ -98,6 +113,10 @@ def is_tf_available(): return _tf_available +def is_torch_tpu_available(): + return _torch_tpu_available + + def add_start_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2caccdded9..d642a22295 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -23,7 +23,7 @@ from .data.data_collator import DataCollator, DefaultDataCollator from 
.modeling_utils import PreTrainedModel from .optimization import AdamW, get_linear_schedule_with_warmup from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput -from .training_args import TrainingArguments, is_tpu_available +from .training_args import TrainingArguments, is_torch_tpu_available try: @@ -38,7 +38,7 @@ def is_apex_available(): return _has_apex -if is_tpu_available(): +if is_torch_tpu_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met import torch_xla.distributed.parallel_loader as pl @@ -218,7 +218,7 @@ class Trainer: # Create output directory if needed if self.is_world_master(): os.makedirs(self.args.output_dir, exist_ok=True) - if is_tpu_available(): + if is_torch_tpu_available(): # Set an xla_device flag on the model's config. # We'll find a more elegant and not need to do this in the future. self.model.config.xla_device = True @@ -226,7 +226,7 @@ class Trainer: def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") - if is_tpu_available(): + if is_torch_tpu_available(): train_sampler = get_tpu_sampler(self.train_dataset) else: train_sampler = ( @@ -251,7 +251,7 @@ class Trainer: eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - if is_tpu_available(): + if is_torch_tpu_available(): sampler = SequentialDistributedSampler( eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) @@ -272,7 +272,7 @@ class Trainer: def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: # We use the same batch_size as for eval. - if is_tpu_available(): + if is_torch_tpu_available(): sampler = SequentialDistributedSampler( test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) @@ -407,7 +407,7 @@ class Trainer: self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={}) # Train! 
- if is_tpu_available(): + if is_torch_tpu_available(): total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() else: total_train_batch_size = ( @@ -455,7 +455,7 @@ class Trainer: if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) - if is_tpu_available(): + if is_torch_tpu_available(): parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader( self.args.device ) @@ -482,7 +482,7 @@ class Trainer: else: torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) - if is_tpu_available(): + if is_torch_tpu_available(): xm.optimizer_step(optimizer) else: optimizer.step() @@ -525,7 +525,7 @@ class Trainer: if self.is_world_master(): self._rotate_checkpoints() - if is_tpu_available(): + if is_torch_tpu_available(): xm.rendezvous("saving_optimizer_states") xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) @@ -588,7 +588,7 @@ class Trainer: return loss.item() def is_local_master(self) -> bool: - if is_tpu_available(): + if is_torch_tpu_available(): return xm.is_master_ordinal(local=True) else: return self.args.local_rank in [-1, 0] @@ -598,7 +598,7 @@ class Trainer: This will be True only in one process, even in distributed mode, even when training on multiple machines. """ - if is_tpu_available(): + if is_torch_tpu_available(): return xm.is_master_ordinal(local=False) else: return self.args.local_rank == -1 or torch.distributed.get_rank() == 0 @@ -611,7 +611,7 @@ class Trainer: Will only save from the world_master process (unless in TPUs). 
""" - if is_tpu_available(): + if is_torch_tpu_available(): self._save_tpu(output_dir) elif self.is_world_master(): self._save(output_dir) @@ -746,7 +746,7 @@ class Trainer: label_ids: torch.Tensor = None model.eval() - if is_tpu_available(): + if is_torch_tpu_available(): dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) for inputs in tqdm(dataloader, desc=description): @@ -780,7 +780,7 @@ class Trainer: preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) if label_ids is not None: label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) - elif is_tpu_available(): + elif is_torch_tpu_available(): # tpu-comment: Get all predictions and labels from all worker shards of eval dataset if preds is not None: preds = xm.mesh_reduce("eval_preds", preds, torch.cat) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index da6862f2c9..0277d0034c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -5,25 +5,15 @@ import os from dataclasses import dataclass, field from typing import Any, Dict, Optional, Tuple -from .file_utils import cached_property, is_torch_available, torch_required +from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required if is_torch_available(): import torch - -try: +if is_torch_tpu_available(): import torch_xla.core.xla_model as xm - _has_tpu = True -except ImportError: - _has_tpu = False - - -@torch_required -def is_tpu_available(): - return _has_tpu - logger = logging.getLogger(__name__) @@ -176,7 +166,7 @@ class TrainingArguments: if self.no_cuda: device = torch.device("cpu") n_gpu = 0 - elif is_tpu_available(): + elif is_torch_tpu_available(): device = xm.xla_device() n_gpu = 0 elif self.local_rank == -1: diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 20e830461e..b891582c50 100644 --- 
a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -33,6 +33,21 @@ class BenchmarkTest(unittest.TestCase): self.check_results_dict_not_empty(results.time_inference_result) self.check_results_dict_not_empty(results.memory_inference_result) + def test_inference_torchscript(self): + MODEL_ID = "sshleifer/tiny-gpt2" + benchmark_args = PyTorchBenchmarkArguments( + models=[MODEL_ID], + training=False, + no_inference=False, + torchscript=True, + sequence_lengths=[8], + batch_sizes=[1], + ) + benchmark = PyTorchBenchmark(benchmark_args) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_inference_result) + self.check_results_dict_not_empty(results.memory_inference_result) + def test_train_no_configs(self): MODEL_ID = "sshleifer/tiny-gpt2" benchmark_args = PyTorchBenchmarkArguments( @@ -76,6 +91,22 @@ class BenchmarkTest(unittest.TestCase): self.check_results_dict_not_empty(results.time_train_result) self.check_results_dict_not_empty(results.memory_train_result) + def test_train_with_configs_torchscript(self): + MODEL_ID = "sshleifer/tiny-gpt2" + config = AutoConfig.from_pretrained(MODEL_ID) + benchmark_args = PyTorchBenchmarkArguments( + models=[MODEL_ID], + training=True, + no_inference=True, + torchscript=True, + sequence_lengths=[8], + batch_sizes=[1], + ) + benchmark = PyTorchBenchmark(benchmark_args, configs=[config]) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_train_result) + self.check_results_dict_not_empty(results.memory_train_result) + def test_train_encoder_decoder_with_configs(self): MODEL_ID = "sshleifer/tinier_bart" config = AutoConfig.from_pretrained(MODEL_ID)