From 52972e70c7b95ad8c80572fc50d843ee4697b7f6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 7 Jun 2023 13:27:07 +0200 Subject: [PATCH] [Wav2Vec2] Fix torch srcipt (#24062) * [Wav2Vec2] Fix torch srcipt * fix more --- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 5 ++--- tests/models/wav2vec2/test_modeling_wav2vec2.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 8b4d7874c6..80d05d3777 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1178,8 +1178,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): module.gradient_checkpointing = value - @property - def _adapters(self): + def _get_adapters(self): if self.config.adapter_attn_dim is None: raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.") @@ -1339,7 +1338,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): f" directory containing a file named {filepath}." ) - adapter_weights = self._adapters + adapter_weights = self._get_adapters() unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys()) missing_keys = set(adapter_weights.keys()) - set(state_dict.keys()) diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index 8fc82eb96e..cf41dd9a30 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -297,7 +297,7 @@ class Wav2Vec2ModelTester: config.adapter_attn_dim = 16 model = Wav2Vec2ForCTC(config=config) - self.parent.assertIsNotNone(model._adapters) + self.parent.assertIsNotNone(model._get_adapters()) model.to(torch_device) model.eval() @@ -1146,7 +1146,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): model = Wav2Vec2ForCTC.from_pretrained(tempdir) logits = get_logits(model, input_features) - adapter_weights = model._adapters + adapter_weights = model._get_adapters() # save safe weights safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng")) @@ -1168,7 +1168,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): model = Wav2Vec2ForCTC.from_pretrained(tempdir) logits = get_logits(model, input_features) - adapter_weights = model._adapters + adapter_weights = model._get_adapters() # save pt weights pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))