From 10704e12094b09a069bb4375a422c83a3c4f44b1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 17 May 2022 18:20:36 +0200 Subject: [PATCH] [Test] Fix W2V-Conformer integration test (#17303) * [Test] Fix W2V-Conformer integration test * correct w2v2 * up --- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 1 - .../wav2vec2_conformer/modeling_wav2vec2_conformer.py | 11 +++-------- .../test_modeling_wav2vec2_conformer.py | 4 ++++ 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 708e007698..06e91446c4 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1414,7 +1414,6 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices >>> from datasets import load_dataset - >>> import soundfile as sf >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base") >>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base") diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 40edd83679..e79224c077 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -1442,7 +1442,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2-base->wav2vec2-conformer-rel-pos-large,wav2vec2->wav2vec2_conformer + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large def forward( self, input_values: Optional[torch.Tensor], @@ -1470,14 +1470,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices >>> from datasets import load_dataset - >>> import soundfile as sf - >>> feature_extractor = AutoFeatureExtractor.from_pretrained( - ... "facebook/wav2vec2_conformer-conformer-rel-pos-large" - ... ) - >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained( - ... "facebook/wav2vec2_conformer-conformer-rel-pos-large" - ... ) + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large") + >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large") >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1 diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py index a3d6a91b76..cb2719a591 100644 --- a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py +++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py @@ -581,6 +581,10 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, unittest.TestCase): module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: module.bias.data.fill_(3) + if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None: + module.pos_bias_u.data.fill_(3) + if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None: + module.pos_bias_v.data.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: