Option to benchmark only one of the two libraries

This commit is contained in:
Lysandre
2019-10-22 13:32:23 -04:00
parent 777faa8ae7
commit 1cfd974868

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Benchmarking the library on inference and training """ """ Benchmarking the library on inference and training """
import tensorflow as tf
# If checking the tensors placement # If checking the tensors placement
# tf.debugging.set_log_device_placement(True) # tf.debugging.set_log_device_placement(True)
@@ -23,15 +22,18 @@ from typing import List
import timeit import timeit
from transformers import is_tf_available, is_torch_available from transformers import is_tf_available, is_torch_available
from time import time from time import time
import torch
import argparse import argparse
import csv import csv
if not is_torch_available() or not is_tf_available(): if is_tf_available():
raise ImportError("TensorFlow and Pytorch should be installed on the system.") import tensorflow as tf
from transformers import TFAutoModel
from transformers import AutoConfig, AutoModel, AutoTokenizer, TFAutoModel if is_torch_available():
import torch
from transformers import AutoModel
from transformers import AutoConfig, AutoTokenizer
input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
the Director of Hatcheries and Conditioning entered the room, in the the Director of Hatcheries and Conditioning entered the room, in the
@@ -434,26 +436,31 @@ def main():
print("Running with arguments", args) print("Running with arguments", args)
if args.torch: if args.torch:
create_setup_and_compute( if is_torch_available():
model_names=args.models, create_setup_and_compute(
tensorflow=False, model_names=args.models,
gpu=args.torch_cuda, tensorflow=False,
torchscript=args.torchscript, gpu=args.torch_cuda,
save_to_csv=args.save_to_csv, torchscript=args.torchscript,
csv_filename=args.csv_filename, save_to_csv=args.save_to_csv,
average_over=args.average_over csv_filename=args.csv_filename,
) average_over=args.average_over
)
else:
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
if args.tensorflow: if args.tensorflow:
create_setup_and_compute( if is_tf_available():
model_names=args.models, create_setup_and_compute(
tensorflow=True, model_names=args.models,
xla=args.xla, tensorflow=True,
save_to_csv=args.save_to_csv, xla=args.xla,
csv_filename=args.csv_filename, save_to_csv=args.save_to_csv,
average_over=args.average_over csv_filename=args.csv_filename,
) average_over=args.average_over
)
else:
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
if __name__ == '__main__': if __name__ == '__main__':
main() main()