Use the Keras set_random_seed in tests (#30504)
Use the Keras set_random_seed to ensure reproducible weight initialization
This commit is contained in:
@@ -541,11 +541,10 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_load_default_pipelines_tf(self):
|
def test_load_default_pipelines_tf(self):
|
||||||
import tensorflow as tf
|
from transformers.modeling_tf_utils import keras
|
||||||
|
|
||||||
from transformers.pipelines import SUPPORTED_TASKS
|
from transformers.pipelines import SUPPORTED_TASKS
|
||||||
|
|
||||||
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
|
set_seed_fn = lambda: keras.utils.set_random_seed(0) # noqa: E731
|
||||||
for task in SUPPORTED_TASKS.keys():
|
for task in SUPPORTED_TASKS.keys():
|
||||||
if task == "table-question-answering":
|
if task == "table-question-answering":
|
||||||
# test table in seperate test due to more dependencies
|
# test table in seperate test due to more dependencies
|
||||||
@@ -553,7 +552,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
|
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
|
||||||
|
|
||||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
# clean-up as much as possible GPU memory occupied by TF
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user