[Wav2Vec2] Fix torch srcipt (#24062)
* [Wav2Vec2] Fix torch srcipt * fix more
This commit is contained in:
committed by
GitHub
parent
612b2a1a6d
commit
52972e70c7
@@ -1178,8 +1178,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)):
|
if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)):
|
||||||
module.gradient_checkpointing = value
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
@property
|
def _get_adapters(self):
|
||||||
def _adapters(self):
|
|
||||||
if self.config.adapter_attn_dim is None:
|
if self.config.adapter_attn_dim is None:
|
||||||
raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")
|
raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")
|
||||||
|
|
||||||
@@ -1339,7 +1338,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
f" directory containing a file named {filepath}."
|
f" directory containing a file named {filepath}."
|
||||||
)
|
)
|
||||||
|
|
||||||
adapter_weights = self._adapters
|
adapter_weights = self._get_adapters()
|
||||||
unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
|
unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
|
||||||
missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())
|
missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())
|
||||||
|
|
||||||
|
|||||||
@@ -297,7 +297,7 @@ class Wav2Vec2ModelTester:
|
|||||||
config.adapter_attn_dim = 16
|
config.adapter_attn_dim = 16
|
||||||
model = Wav2Vec2ForCTC(config=config)
|
model = Wav2Vec2ForCTC(config=config)
|
||||||
|
|
||||||
self.parent.assertIsNotNone(model._adapters)
|
self.parent.assertIsNotNone(model._get_adapters())
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -1146,7 +1146,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model = Wav2Vec2ForCTC.from_pretrained(tempdir)
|
model = Wav2Vec2ForCTC.from_pretrained(tempdir)
|
||||||
|
|
||||||
logits = get_logits(model, input_features)
|
logits = get_logits(model, input_features)
|
||||||
adapter_weights = model._adapters
|
adapter_weights = model._get_adapters()
|
||||||
|
|
||||||
# save safe weights
|
# save safe weights
|
||||||
safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng"))
|
safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng"))
|
||||||
@@ -1168,7 +1168,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model = Wav2Vec2ForCTC.from_pretrained(tempdir)
|
model = Wav2Vec2ForCTC.from_pretrained(tempdir)
|
||||||
|
|
||||||
logits = get_logits(model, input_features)
|
logits = get_logits(model, input_features)
|
||||||
adapter_weights = model._adapters
|
adapter_weights = model._get_adapters()
|
||||||
|
|
||||||
# save pt weights
|
# save pt weights
|
||||||
pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))
|
pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))
|
||||||
|
|||||||
Reference in New Issue
Block a user