From e1f6e4903a15f6eef27fa568e2ebf7ac3a4fce49 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Wed, 8 Sep 2021 19:51:51 +0300 Subject: [PATCH] Fix integration tests for TFWav2Vec2 and TFHubert --- tests/test_modeling_tf_hubert.py | 16 +++++++--------- tests/test_modeling_tf_wav2vec2.py | 8 +++----- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/test_modeling_tf_hubert.py b/tests/test_modeling_tf_hubert.py index f97f0c10a7..7e85519553 100644 --- a/tests/test_modeling_tf_hubert.py +++ b/tests/test_modeling_tf_hubert.py @@ -511,14 +511,12 @@ class TFHubertModelIntegrationTest(unittest.TestCase): self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) def test_inference_ctc_normal_batched(self): - model = TFHubertForCTC.from_pretrained("facebook/hubert-base-ls960") - processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-base-ls960", do_lower_case=True) + model = TFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft") + processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True) input_speech = self._load_datasamples(2) - input_values = processor( - input_speech, return_tensors="tf", padding=True, truncation=True, sampling_rate=16000 - ).input_values + input_values = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000).input_values logits = model(input_values).logits @@ -527,7 +525,7 @@ class TFHubertModelIntegrationTest(unittest.TestCase): EXPECTED_TRANSCRIPTIONS = [ "a man said to the universe sir i exist", - "sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore", + "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore", ] self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) @@ -537,20 +535,20 @@ class TFHubertModelIntegrationTest(unittest.TestCase): input_speech = self._load_datasamples(4) - inputs = processor(input_speech, return_tensors="tf", padding=True, truncation=True) + inputs = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000) input_values = inputs.input_values attention_mask = inputs.attention_mask logits = model(input_values, attention_mask=attention_mask).logits - predicted_ids = tf.argmax(logits, dim=-1) + predicted_ids = tf.argmax(logits, axis=-1) predicted_trans = processor.batch_decode(predicted_ids) EXPECTED_TRANSCRIPTIONS = [ "a man said to the universe sir i exist", "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore", "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about", - "his instant panic was followed by a small sharp blow high on his chest", + "his instant of panic was followed by a small sharp blow high on his chest", ] self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) diff --git a/tests/test_modeling_tf_wav2vec2.py b/tests/test_modeling_tf_wav2vec2.py index 889790c75e..c7844ad4b2 100644 --- a/tests/test_modeling_tf_wav2vec2.py +++ b/tests/test_modeling_tf_wav2vec2.py @@ -516,9 +516,7 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): input_speech = self._load_datasamples(2) - input_values = processor( - input_speech, return_tensors="tf", padding=True, truncation=True, sampling_rate=16000 - ).input_values + input_values = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000).input_values logits = model(input_values).logits @@ -537,14 +535,14 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): input_speech = self._load_datasamples(4) - inputs = processor(input_speech, return_tensors="tf", padding=True, truncation=True) + inputs = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000) input_values = inputs.input_values attention_mask = inputs.attention_mask logits = model(input_values, attention_mask=attention_mask).logits - predicted_ids = tf.argmax(logits, dim=-1) + predicted_ids = tf.argmax(logits, axis=-1) predicted_trans = processor.batch_decode(predicted_ids) EXPECTED_TRANSCRIPTIONS = [