[tests] remove flax-pt equivalence and cross tests (#36283)

This commit is contained in:
Joao Gante
2025-02-19 15:13:27 +00:00
committed by GitHub
parent fa8cdccd91
commit 99adc74462
39 changed files with 33 additions and 3103 deletions

View File

@@ -31,13 +31,11 @@ from datasets import Dataset
from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import ( # noqa: F401
from transformers.testing_utils import (
CaptureLogger,
_tf_gpu_memory_limit,
require_tf,
require_tf2onnx,
slow,
torch_device,
)
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
from transformers.utils.generic import ModelOutput
@@ -73,20 +71,6 @@ if is_tf_available():
tf.config.experimental.enable_tensor_float_32_execution(False)
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
try:
tf.config.set_logical_device_configuration(
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
)
logical_gpus = tf.config.list_logical_devices("GPU")
print("Logical GPUs", logical_gpus)
except RuntimeError as e:
# Virtual devices must be set before GPUs have been initialized
print(e)
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)