Fix TF CTC tests (#21606)
This commit is contained in:
@@ -321,6 +321,20 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
|
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
|
def test_dataset_conversion(self):
|
||||||
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
self.model_tester.batch_size = 2
|
||||||
|
super().test_dataset_conversion()
|
||||||
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
|
def test_keras_fit(self):
|
||||||
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
self.model_tester.batch_size = 2
|
||||||
|
super().test_keras_fit()
|
||||||
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
@@ -431,20 +445,18 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
# TODO: fix me
|
|
||||||
@unittest.skip(reason="Crashing on CI, temporarily skipped")
|
|
||||||
def test_dataset_conversion(self):
|
def test_dataset_conversion(self):
|
||||||
default_batch_size = self.model_tester.batch_size
|
default_batch_size = self.model_tester.batch_size
|
||||||
self.model_tester.batch_size = 2
|
self.model_tester.batch_size = 2
|
||||||
super().test_dataset_conversion()
|
super().test_dataset_conversion()
|
||||||
self.model_tester.batch_size = default_batch_size
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_model_from_pretrained(self):
|
|
||||||
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
|
|
||||||
self.assertIsNotNone(model)
|
|
||||||
|
|
||||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
def test_keras_fit(self):
|
def test_keras_fit(self):
|
||||||
default_batch_size = self.model_tester.batch_size
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
|||||||
@@ -396,7 +396,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_keras_fit(self):
|
def test_keras_fit(self):
|
||||||
default_batch_size = self.model_tester.batch_size
|
default_batch_size = self.model_tester.batch_size
|
||||||
self.model_tester.batch_size = 2
|
self.model_tester.batch_size = 2
|
||||||
super().test_dataset_conversion()
|
super().test_keras_fit()
|
||||||
self.model_tester.batch_size = default_batch_size
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
|
|
||||||
@@ -527,7 +527,7 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_keras_fit(self):
|
def test_keras_fit(self):
|
||||||
default_batch_size = self.model_tester.batch_size
|
default_batch_size = self.model_tester.batch_size
|
||||||
self.model_tester.batch_size = 2
|
self.model_tester.batch_size = 2
|
||||||
super().test_dataset_conversion()
|
super().test_keras_fit()
|
||||||
self.model_tester.batch_size = default_batch_size
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user