From b28ebb265590ff12bb888b3ca5e07d56922beb6a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 2 Aug 2023 18:11:15 +0200 Subject: [PATCH] [MMS] Fix mms (#25267) * [MMS] Fix mms * [MMS] Fix mms * fix mms loading * Apply suggestions from code review * make style * Update tests/models/wav2vec2/test_modeling_wav2vec2.py --- src/transformers/modeling_utils.py | 5 ++- .../models/wav2vec2/test_modeling_wav2vec2.py | 37 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7c39733be4..9e167cfdee 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3047,6 +3047,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix offload_state_dict = True is_sharded_safetensors = is_safetensors and sharded_metadata is not None + + # tie the model weights before retrieving the state_dict + model.tie_weights() + # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) @@ -3092,7 +3096,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix model_buffers = {".".join([prefix, key]) for key in model_buffers} unexpected_keys = list(unexpected_keys - model_buffers) - model.tie_weights() if device_map is None: ptrs = collections.defaultdict(list) for name, tensor in model.state_dict().items(): diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index 4db9b156db..630a5d8e85 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -1151,6 +1151,43 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3)) + # test that loading adapter weights with mismatched vocab sizes can be loaded + def test_load_target_lang_with_mismatched_size(self): + processor = Wav2Vec2Processor.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True + ) + + def get_logits(model, input_features): + model = model.to(torch_device) + batch = processor( + input_features, + padding=True, + sampling_rate=processor.feature_extractor.sampling_rate, + return_tensors="pt", + ) + + with torch.no_grad(): + logits = model( + input_values=batch["input_values"].to(torch_device), + attention_mask=batch["attention_mask"].to(torch_device), + ).logits + return logits + + input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]] + + model = Wav2Vec2ForCTC.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2-adapter", target_lang="fr", ignore_mismatched_sizes=True + ) + + logits = get_logits(model, input_features) + + model_2 = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter") + model_2.load_adapter("fr") + + logits_2 = get_logits(model_2, input_features) + + self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3)) + def test_load_attn_adapter(self): processor = Wav2Vec2Processor.from_pretrained( "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True