Add Flax example tests (#14599)

* add test for glue

* add tests for clm

* fix clm test

* add summrization tests

* more tests

* fix few tests

* add test for t5 mlm

* fix t5 mlm test

* fix tests for multi device

* cleanup

* ci job

* fix metric file name

* make t5 more robust
This commit is contained in:
Suraj Patil
2021-12-06 10:48:58 +05:30
committed by GitHub
parent 803a8cd18f
commit c5bd732ac6
11 changed files with 553 additions and 6 deletions

View File

@@ -600,7 +600,7 @@ def require_deepspeed(test_case):
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch or tf is used)
Return the number of available gpus (regardless of whether torch, tf or jax is used)
"""
if is_torch_available():
import torch
@@ -610,6 +610,10 @@ def get_gpu_count():
import tensorflow as tf
return len(tf.config.list_physical_devices("GPU"))
elif is_flax_available():
import jax
return jax.device_count()
else:
return 0