From bc6f51e53925e118f04a76f9a59a3d16fc9684a2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 9 Jun 2021 20:41:59 +0100 Subject: [PATCH] [Wav2Vec2ForPretraining] Correct checkpoints wav2vec2 & fix tests (#12089) * fix_torch_device_generate_test * remove @ * fix tests --- tests/test_modeling_wav2vec2.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index 0934967dc2..f9fa91a476 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -349,6 +349,8 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): module.bias.data.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) + if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: + module.masked_spec_embed.data.fill_(3) @slow def test_model_from_pretrained(self): @@ -487,6 +489,8 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): module.bias.data.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) + if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: + module.masked_spec_embed.data.fill_(3) def test_model_for_pretraining(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -677,10 +681,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) def test_inference_integration(self): - model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base") + model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base") model.to(torch_device) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - "patrickvonplaten/wav2vec2-base", return_attention_mask=True + "facebook/wav2vec2-base", return_attention_mask=True ) input_speech = self._load_datasamples(2) @@ -723,10 +727,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3)) def test_inference_pretrained(self): - model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base") + model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base") model.to(torch_device) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - "patrickvonplaten/wav2vec2-base", return_attention_mask=True + "facebook/wav2vec2-base", return_attention_mask=True ) input_speech = self._load_datasamples(2) @@ -761,7 +765,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): # ... now compare to randomly initialized model - config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base") + config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base") model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval() with torch.no_grad(): @@ -785,9 +789,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): # => the cosine similarity between quantized states and predicted states is very likely < 0.1 self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0) + @unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU") def test_loss_pretraining(self): model = Wav2Vec2ForPreTraining.from_pretrained( - "patrickvonplaten/wav2vec2-base", + "facebook/wav2vec2-base", attention_dropout=0.0, feat_proj_dropout=0.0, hidden_dropout=0.0, @@ -796,7 +801,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): model.to(torch_device).train() feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - "patrickvonplaten/wav2vec2-base", return_attention_mask=True + "facebook/wav2vec2-base", return_attention_mask=True ) input_speech = self._load_datasamples(2) @@ -829,6 +834,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3) # check overall loss (contrastive loss + diversity loss) - expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612 + expected_loss = 62.5170 self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)