Option to benchmark only one of the two libraries
This commit is contained in:
@@ -14,7 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Benchmarking the library on inference and training """
|
""" Benchmarking the library on inference and training """
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
# If checking the tensors placement
|
# If checking the tensors placement
|
||||||
# tf.debugging.set_log_device_placement(True)
|
# tf.debugging.set_log_device_placement(True)
|
||||||
@@ -23,15 +22,18 @@ from typing import List
|
|||||||
import timeit
|
import timeit
|
||||||
from transformers import is_tf_available, is_torch_available
|
from transformers import is_tf_available, is_torch_available
|
||||||
from time import time
|
from time import time
|
||||||
import torch
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
if not is_torch_available() or not is_tf_available():
|
if is_tf_available():
|
||||||
raise ImportError("TensorFlow and Pytorch should be installed on the system.")
|
import tensorflow as tf
|
||||||
|
from transformers import TFAutoModel
|
||||||
|
|
||||||
from transformers import AutoConfig, AutoModel, AutoTokenizer, TFAutoModel
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoTokenizer
|
||||||
|
|
||||||
input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
|
input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
|
||||||
the Director of Hatcheries and Conditioning entered the room, in the
|
the Director of Hatcheries and Conditioning entered the room, in the
|
||||||
@@ -434,6 +436,7 @@ def main():
|
|||||||
print("Running with arguments", args)
|
print("Running with arguments", args)
|
||||||
|
|
||||||
if args.torch:
|
if args.torch:
|
||||||
|
if is_torch_available():
|
||||||
create_setup_and_compute(
|
create_setup_and_compute(
|
||||||
model_names=args.models,
|
model_names=args.models,
|
||||||
tensorflow=False,
|
tensorflow=False,
|
||||||
@@ -443,8 +446,11 @@ def main():
|
|||||||
csv_filename=args.csv_filename,
|
csv_filename=args.csv_filename,
|
||||||
average_over=args.average_over
|
average_over=args.average_over
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
|
||||||
|
|
||||||
if args.tensorflow:
|
if args.tensorflow:
|
||||||
|
if is_tf_available():
|
||||||
create_setup_and_compute(
|
create_setup_and_compute(
|
||||||
model_names=args.models,
|
model_names=args.models,
|
||||||
tensorflow=True,
|
tensorflow=True,
|
||||||
@@ -453,7 +459,8 @@ def main():
|
|||||||
csv_filename=args.csv_filename,
|
csv_filename=args.csv_filename,
|
||||||
average_over=args.average_over
|
average_over=args.average_over
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user