Add DistilHuBERT (#14174)
* Add conversion * Rename * Add an integration test and remove layer_norm * Remove layer_norm from the converter * wording * Fix imports
This commit is contained in:
@@ -760,3 +760,46 @@ class HubertModelIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-1))
|
||||
|
||||
def test_inference_distilhubert(self):
|
||||
model = HubertModel.from_pretrained("anton-l/distilhubert").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/distilhubert")
|
||||
|
||||
# TODO: can't test on batched inputs due to incompatible padding https://github.com/pytorch/fairseq/pull/3572
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values).last_hidden_state
|
||||
|
||||
# expected outputs taken from the original SEW implementation
|
||||
expected_outputs_first = torch.tensor(
|
||||
[
|
||||
[
|
||||
[-0.3505, 0.1167, 0.0608, 0.1294],
|
||||
[-0.3085, 0.0481, 0.1106, 0.0955],
|
||||
[-0.3107, -0.0391, 0.0739, 0.1360],
|
||||
[-0.2385, -0.1795, -0.0928, 0.2389],
|
||||
]
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
expected_outputs_last = torch.tensor(
|
||||
[
|
||||
[
|
||||
[-0.0732, 0.0255, 0.0529, -0.1372],
|
||||
[-0.0812, 0.1259, 0.0564, -0.0438],
|
||||
[-0.0054, 0.0758, -0.0002, -0.1617],
|
||||
[0.0133, -0.0320, -0.0687, 0.0062],
|
||||
]
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
expected_output_sum = -3776.0730
|
||||
|
||||
self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
|
||||
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
|
||||
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
|
||||
|
||||
Reference in New Issue
Block a user