[s2s] test_distributed_eval (#8315)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Stas Bekman
2020-11-05 13:01:15 -08:00
committed by GitHub
parent 04e442d575
commit d787935a14
4 changed files with 56 additions and 8 deletions

View File

@@ -297,6 +297,22 @@ def require_ray(test_case):
return test_case
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch or tf is used)
"""
if _torch_available:
import torch
return torch.cuda.device_count()
elif _tf_available:
import tensorflow as tf
return len(tf.config.list_physical_devices("GPU"))
else:
return 0
def get_tests_dir(append_path=None):
"""
Args: