Add Flax example tests (#14599)
* add test for glue * add tests for clm * fix clm test * add summrization tests * more tests * fix few tests * add test for t5 mlm * fix t5 mlm test * fix tests for multi device * cleanup * ci job * fix metric file name * make t5 more robust
This commit is contained in:
@@ -600,7 +600,7 @@ def require_deepspeed(test_case):
|
||||
|
||||
def get_gpu_count():
|
||||
"""
|
||||
Return the number of available gpus (regardless of whether torch or tf is used)
|
||||
Return the number of available gpus (regardless of whether torch, tf or jax is used)
|
||||
"""
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -610,6 +610,10 @@ def get_gpu_count():
|
||||
import tensorflow as tf
|
||||
|
||||
return len(tf.config.list_physical_devices("GPU"))
|
||||
elif is_flax_available():
|
||||
import jax
|
||||
|
||||
return jax.device_count()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user