[tests] remove flax-pt equivalence and cross tests (#36283)
This commit is contained in:
@@ -31,13 +31,11 @@ from datasets import Dataset
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import ( # noqa: F401
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
_tf_gpu_memory_limit,
|
||||
require_tf,
|
||||
require_tf2onnx,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
|
||||
from transformers.utils.generic import ModelOutput
|
||||
@@ -73,20 +71,6 @@ if is_tf_available():
|
||||
|
||||
tf.config.experimental.enable_tensor_float_32_execution(False)
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
|
||||
try:
|
||||
tf.config.set_logical_device_configuration(
|
||||
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
|
||||
)
|
||||
logical_gpus = tf.config.list_logical_devices("GPU")
|
||||
print("Logical GPUs", logical_gpus)
|
||||
except RuntimeError as e:
|
||||
# Virtual devices must be set before GPUs have been initialized
|
||||
print(e)
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
|
||||
Reference in New Issue
Block a user