Fix integration tests for TFWav2Vec2 and TFHubert
This commit is contained in:
@@ -511,14 +511,12 @@ class TFHubertModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||||
|
|
||||||
def test_inference_ctc_normal_batched(self):
|
def test_inference_ctc_normal_batched(self):
|
||||||
model = TFHubertForCTC.from_pretrained("facebook/hubert-base-ls960")
|
model = TFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
|
||||||
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-base-ls960", do_lower_case=True)
|
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True)
|
||||||
|
|
||||||
input_speech = self._load_datasamples(2)
|
input_speech = self._load_datasamples(2)
|
||||||
|
|
||||||
input_values = processor(
|
input_values = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000).input_values
|
||||||
input_speech, return_tensors="tf", padding=True, truncation=True, sampling_rate=16000
|
|
||||||
).input_values
|
|
||||||
|
|
||||||
logits = model(input_values).logits
|
logits = model(input_values).logits
|
||||||
|
|
||||||
@@ -527,7 +525,7 @@ class TFHubertModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_TRANSCRIPTIONS = [
|
EXPECTED_TRANSCRIPTIONS = [
|
||||||
"a man said to the universe sir i exist",
|
"a man said to the universe sir i exist",
|
||||||
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
|
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
|
||||||
]
|
]
|
||||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||||
|
|
||||||
@@ -537,20 +535,20 @@ class TFHubertModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
input_speech = self._load_datasamples(4)
|
||||||
|
|
||||||
inputs = processor(input_speech, return_tensors="tf", padding=True, truncation=True)
|
inputs = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000)
|
||||||
|
|
||||||
input_values = inputs.input_values
|
input_values = inputs.input_values
|
||||||
attention_mask = inputs.attention_mask
|
attention_mask = inputs.attention_mask
|
||||||
|
|
||||||
logits = model(input_values, attention_mask=attention_mask).logits
|
logits = model(input_values, attention_mask=attention_mask).logits
|
||||||
|
|
||||||
predicted_ids = tf.argmax(logits, dim=-1)
|
predicted_ids = tf.argmax(logits, axis=-1)
|
||||||
predicted_trans = processor.batch_decode(predicted_ids)
|
predicted_trans = processor.batch_decode(predicted_ids)
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPTIONS = [
|
EXPECTED_TRANSCRIPTIONS = [
|
||||||
"a man said to the universe sir i exist",
|
"a man said to the universe sir i exist",
|
||||||
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
|
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
|
||||||
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
|
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
|
||||||
"his instant panic was followed by a small sharp blow high on his chest",
|
"his instant of panic was followed by a small sharp blow high on his chest",
|
||||||
]
|
]
|
||||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||||
|
|||||||
@@ -516,9 +516,7 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
input_speech = self._load_datasamples(2)
|
input_speech = self._load_datasamples(2)
|
||||||
|
|
||||||
input_values = processor(
|
input_values = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000).input_values
|
||||||
input_speech, return_tensors="tf", padding=True, truncation=True, sampling_rate=16000
|
|
||||||
).input_values
|
|
||||||
|
|
||||||
logits = model(input_values).logits
|
logits = model(input_values).logits
|
||||||
|
|
||||||
@@ -537,14 +535,14 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
input_speech = self._load_datasamples(4)
|
||||||
|
|
||||||
inputs = processor(input_speech, return_tensors="tf", padding=True, truncation=True)
|
inputs = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000)
|
||||||
|
|
||||||
input_values = inputs.input_values
|
input_values = inputs.input_values
|
||||||
attention_mask = inputs.attention_mask
|
attention_mask = inputs.attention_mask
|
||||||
|
|
||||||
logits = model(input_values, attention_mask=attention_mask).logits
|
logits = model(input_values, attention_mask=attention_mask).logits
|
||||||
|
|
||||||
predicted_ids = tf.argmax(logits, dim=-1)
|
predicted_ids = tf.argmax(logits, axis=-1)
|
||||||
predicted_trans = processor.batch_decode(predicted_ids)
|
predicted_trans = processor.batch_decode(predicted_ids)
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPTIONS = [
|
EXPECTED_TRANSCRIPTIONS = [
|
||||||
|
|||||||
Reference in New Issue
Block a user