[Wav2Vec2 - MMS] Correct directly loading adapters weights (#24335)
* Correct direct lang loading * correct more * revert black * Use tie weights instead= * add tests * add tests * make style
This commit is contained in:
committed by
GitHub
parent
e5c760d636
commit
b0513b013b
@@ -1117,6 +1117,40 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_load_and_set_attn_adapter(self):
|
||||
processor = Wav2Vec2Processor.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||
)
|
||||
|
||||
def get_logits(model, input_features):
|
||||
model = model.to(torch_device)
|
||||
batch = processor(
|
||||
input_features,
|
||||
padding=True,
|
||||
sampling_rate=processor.feature_extractor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(
|
||||
input_values=batch["input_values"].to(torch_device),
|
||||
attention_mask=batch["attention_mask"].to(torch_device),
|
||||
).logits
|
||||
return logits
|
||||
|
||||
input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter", target_lang="it")
|
||||
|
||||
logits = get_logits(model, input_features)
|
||||
|
||||
model_2 = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
|
||||
model_2.load_adapter("it")
|
||||
|
||||
logits_2 = get_logits(model_2, input_features)
|
||||
|
||||
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||
|
||||
def test_load_attn_adapter(self):
|
||||
processor = Wav2Vec2Processor.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||
|
||||
Reference in New Issue
Block a user