Replace input_values_processing with unpack_inputs (#21502)

* Replace input_values_prrocessing with unpack_inputs

* Skip test failing with OOM

* Update tests
This commit is contained in:
amyeroberts
2023-02-10 18:19:39 +00:00
committed by GitHub
parent 557125637d
commit cb56590111
4 changed files with 96 additions and 427 deletions

View File

@@ -304,18 +304,15 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
# Hubert has no inputs_embeds
@unittest.skip(reason="Hubert has no input embeddings")
def test_inputs_embeds(self):
pass
# Hubert cannot resize token embeddings
# since it has no tokens embeddings
@unittest.skip(reason="Hubert has no tokens embeddings")
def test_resize_tokens_embeddings(self):
pass
# Hubert has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
@unittest.skip(reason="Hubert has no input embeddings")
def test_model_common_attributes(self):
pass
@@ -324,10 +321,6 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
self.assertIsNotNone(model)
@unittest.skip("Loss shapes for CTC don't match the base test.")
def test_loss_computation(self):
pass
@require_tf
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -426,29 +419,36 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
# Hubert has no inputs_embeds
@unittest.skip(reason="Hubert has no input embeddings")
def test_inputs_embeds(self):
pass
# Hubert cannot resize token embeddings
# since it has no tokens embeddings
@unittest.skip(reason="Hubert has no tokens embeddings")
def test_resize_tokens_embeddings(self):
pass
# Hubert has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
@unittest.skip(reason="Hubert has no input embeddings or get_input_embeddings method")
def test_model_common_attributes(self):
pass
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_dataset_conversion(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size
@slow
def test_model_from_pretrained(self):
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
self.assertIsNotNone(model)
@unittest.skip("Loss shapes for CTC don't match the base test.")
def test_loss_computation(self):
pass
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_keras_fit()
self.model_tester.batch_size = default_batch_size
@require_tf

View File

@@ -369,18 +369,15 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
def test_inputs_embeds(self):
pass
# Wav2Vec2 cannot resize token embeddings
# since it has no tokens embeddings
@unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
def test_resize_tokens_embeddings(self):
pass
# Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
def test_model_common_attributes(self):
pass
@@ -389,13 +386,19 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsNotNone(model)
@unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_dataset_conversion(self):
pass
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size
@unittest.skip(reason="Training goes OOM and crashes with the default options!")
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_keras_fit(self):
pass
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size
@require_tf
@@ -497,18 +500,15 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
def test_inputs_embeds(self):
pass
# Wav2Vec2 cannot resize token embeddings
# since it has no tokens embeddings
@unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
def test_resize_tokens_embeddings(self):
pass
# Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
def test_model_common_attributes(self):
pass
@@ -517,13 +517,19 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsNotNone(model)
@unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_dataset_conversion(self):
pass
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size
@unittest.skip(reason="Training goes OOM and crashes with the default options!")
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_keras_fit(self):
pass
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size
@require_tf