Replace input_values_processing with unpack_inputs (#21502)
* Replace input_values_prrocessing with unpack_inputs * Skip test failing with OOM * Update tests
This commit is contained in:
@@ -369,18 +369,15 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 cannot resize token embeddings
|
||||
# since it has no tokens embeddings
|
||||
@unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
# and thus the `get_input_embeddings` fn
|
||||
# is not implemented
|
||||
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -389,13 +386,19 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_dataset_conversion(self):
|
||||
pass
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
@unittest.skip(reason="Training goes OOM and crashes with the default options!")
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_keras_fit(self):
|
||||
pass
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
|
||||
@require_tf
|
||||
@@ -497,18 +500,15 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 cannot resize token embeddings
|
||||
# since it has no tokens embeddings
|
||||
@unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
# and thus the `get_input_embeddings` fn
|
||||
# is not implemented
|
||||
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -517,13 +517,19 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_dataset_conversion(self):
|
||||
pass
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
@unittest.skip(reason="Training goes OOM and crashes with the default options!")
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_keras_fit(self):
|
||||
pass
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
|
||||
@require_tf
|
||||
|
||||
Reference in New Issue
Block a user