[Test] Fix W2V-Conformer integration test (#17303)
* [Test] Fix W2V-Conformer integration test * correct w2v2 * up
This commit is contained in:
committed by
GitHub
parent
28a0811652
commit
10704e1209
@@ -1414,7 +1414,6 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
|||||||
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
|
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
|
||||||
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||||
>>> from datasets import load_dataset
|
>>> from datasets import load_dataset
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
|
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
|
||||||
>>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
|
>>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
|
||||||
|
|||||||
@@ -1442,7 +1442,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2-base->wav2vec2-conformer-rel-pos-large,wav2vec2->wav2vec2_conformer
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values: Optional[torch.Tensor],
|
input_values: Optional[torch.Tensor],
|
||||||
@@ -1470,14 +1470,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
|||||||
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
|
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
|
||||||
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
|
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
|
||||||
>>> from datasets import load_dataset
|
>>> from datasets import load_dataset
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(
|
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
||||||
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
|
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
||||||
... )
|
|
||||||
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained(
|
|
||||||
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
|
|
||||||
... )
|
|
||||||
|
|
||||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
|
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
|
||||||
|
|||||||
@@ -581,6 +581,10 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
module.weight_v.data.fill_(3)
|
module.weight_v.data.fill_(3)
|
||||||
if hasattr(module, "bias") and module.bias is not None:
|
if hasattr(module, "bias") and module.bias is not None:
|
||||||
module.bias.data.fill_(3)
|
module.bias.data.fill_(3)
|
||||||
|
if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
|
||||||
|
module.pos_bias_u.data.fill_(3)
|
||||||
|
if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
|
||||||
|
module.pos_bias_v.data.fill_(3)
|
||||||
if hasattr(module, "codevectors") and module.codevectors is not None:
|
if hasattr(module, "codevectors") and module.codevectors is not None:
|
||||||
module.codevectors.data.fill_(3)
|
module.codevectors.data.fill_(3)
|
||||||
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user