Add SEW CTC models (#14158)

* Add SEW CTC models

* Update paths

* Update paths
This commit is contained in:
Anton Lozhkov
2021-10-27 12:21:09 +03:00
committed by GitHub
parent 1e53faeb2e
commit e1dc5afd28
8 changed files with 106 additions and 116 deletions

View File

@@ -22,7 +22,7 @@ import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import SEWConfig, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, tooslow, torch_device
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init
@@ -531,27 +531,24 @@ class SEWModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)
@tooslow
def test_inference_ctc_batched(self):
# Integration check: run a pretrained SEWForCTC checkpoint on a small batch of
# speech inputs, decode the arg-max token ids, and compare the resulting text
# with fixed reference transcriptions.
#
# NOTE(review): this span looks like a rendered diff whose +/- markers and
# indentation were lost, so several statements appear twice (superseded line
# first, replacement second). The enclosing hunk header's line counts
# (27 old, 24 new) are consistent with that reading — confirm against the
# actual commit before treating any single line as current.
# TODO: enable this test once the finetuned models are available
# Two checkpoint ids follow ("...-ft-100h" then "...-ft-ls100h"); presumably the
# first pair is the pre-change path and the second pair its replacement.
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-100h").to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-100h", do_lower_case=True)
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h").to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h", do_lower_case=True)
# Ask the test helper for 2 samples (helper is defined outside this view).
input_speech = self._load_datasamples(2)
# padding=True so the samples can be returned as one batched "pt" tensor.
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
# Only the first of the two forward calls below consumes this mask.
attention_mask = inputs.attention_mask.to(torch_device)
# Inference only: no gradients are needed.
with torch.no_grad():
# Two variants of the forward call: with and without the attention mask.
# Presumably the masked call is the superseded one — TODO confirm.
logits = model(input_values, attention_mask=attention_mask).logits
logits = model(input_values).logits
# Pick the highest-scoring token id along the last dimension, then decode
# the whole batch of id sequences to text.
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
# The next two entries are alternative references for the same sample.
# NOTE(review): the second one's misspellings look like verbatim model
# output rather than typos — verify before "correcting" it.
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
"swet covered brian's body trickling into the tightloine closs hat was the only garment he wore",
]
# Exact, order-sensitive comparison of decoded text against the references.
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)