[s2s] test_distributed_eval (#8315)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -297,6 +297,22 @@ def require_ray(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def get_gpu_count():
|
||||
"""
|
||||
Return the number of available gpus (regardless of whether torch or tf is used)
|
||||
"""
|
||||
if _torch_available:
|
||||
import torch
|
||||
|
||||
return torch.cuda.device_count()
|
||||
elif _tf_available:
|
||||
import tensorflow as tf
|
||||
|
||||
return len(tf.config.list_physical_devices("GPU"))
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def get_tests_dir(append_path=None):
|
||||
"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user