[MMS] Fix mms (#25267)

* [MMS] Fix mms

* [MMS] Fix mms

* fix mms loading

* Apply suggestions from code review

* make style

* Update tests/models/wav2vec2/test_modeling_wav2vec2.py
This commit is contained in:
Patrick von Platen
2023-08-02 18:11:15 +02:00
committed by GitHub
parent ad8321512d
commit b28ebb2655
2 changed files with 41 additions and 1 deletions

View File

@@ -1151,6 +1151,43 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
# test that loading adapter weights with mismatched vocab sizes can be loaded
def test_load_target_lang_with_mismatched_size(self):
processor = Wav2Vec2Processor.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
)
def get_logits(model, input_features):
model = model.to(torch_device)
batch = processor(
input_features,
padding=True,
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
)
with torch.no_grad():
logits = model(
input_values=batch["input_values"].to(torch_device),
attention_mask=batch["attention_mask"].to(torch_device),
).logits
return logits
input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
model = Wav2Vec2ForCTC.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2-adapter", target_lang="fr", ignore_mismatched_sizes=True
)
logits = get_logits(model, input_features)
model_2 = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
model_2.load_adapter("fr")
logits_2 = get_logits(model_2, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
def test_load_attn_adapter(self):
processor = Wav2Vec2Processor.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True