From 1cfd9748683db43af2c98da1a19d39f0efc8cc3b Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 22 Oct 2019 13:32:23 -0400 Subject: [PATCH] Option to benchmark only one of the two libraries --- examples/benchmarks.py | 55 ++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/examples/benchmarks.py b/examples/benchmarks.py index b1153bf566..d03844697d 100644 --- a/examples/benchmarks.py +++ b/examples/benchmarks.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Benchmarking the library on inference and training """ -import tensorflow as tf # If checking the tensors placement # tf.debugging.set_log_device_placement(True) @@ -23,15 +22,18 @@ from typing import List import timeit from transformers import is_tf_available, is_torch_available from time import time -import torch - import argparse import csv -if not is_torch_available() or not is_tf_available(): - raise ImportError("TensorFlow and Pytorch should be installed on the system.") +if is_tf_available(): + import tensorflow as tf + from transformers import TFAutoModel -from transformers import AutoConfig, AutoModel, AutoTokenizer, TFAutoModel +if is_torch_available(): + import torch + from transformers import AutoModel + +from transformers import AutoConfig, AutoTokenizer input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as the Director of Hatcheries and Conditioning entered the room, in the @@ -434,26 +436,31 @@ def main(): print("Running with arguments", args) if args.torch: - create_setup_and_compute( - model_names=args.models, - tensorflow=False, - gpu=args.torch_cuda, - torchscript=args.torchscript, - save_to_csv=args.save_to_csv, - csv_filename=args.csv_filename, - average_over=args.average_over - ) + if is_torch_available(): + create_setup_and_compute( + model_names=args.models, + tensorflow=False, + gpu=args.torch_cuda, + torchscript=args.torchscript, + save_to_csv=args.save_to_csv, + csv_filename=args.csv_filename, + average_over=args.average_over + ) + else: + raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.") if args.tensorflow: - create_setup_and_compute( - model_names=args.models, - tensorflow=True, - xla=args.xla, - save_to_csv=args.save_to_csv, - csv_filename=args.csv_filename, - average_over=args.average_over - ) - + if is_tf_available(): + create_setup_and_compute( + model_names=args.models, + tensorflow=True, + xla=args.xla, + save_to_csv=args.save_to_csv, + csv_filename=args.csv_filename, + average_over=args.average_over + ) + else: + raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.") if __name__ == '__main__': main()