Add SEW CTC models (#14158)
* Add SEW CTC models * Update paths * Update paths
This commit is contained in:
@@ -22,7 +22,7 @@ import pytest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from transformers import SEWConfig, is_torch_available
|
||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, tooslow, torch_device
|
||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
@@ -531,27 +531,24 @@ class SEWModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
|
||||
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)
|
||||
|
||||
@tooslow
|
||||
def test_inference_ctc_batched(self):
|
||||
# TODO: enable this test once the finetuned models are available
|
||||
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-100h").to(torch_device)
|
||||
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-100h", do_lower_case=True)
|
||||
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h").to(torch_device)
|
||||
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values, attention_mask=attention_mask).logits
|
||||
logits = model(input_values).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
|
||||
"swet covered brian's body trickling into the tightloine closs hat was the only garment he wore",
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
Reference in New Issue
Block a user