From b0513b013b10939a2b47ab94933c2cca909716a2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Jun 2023 19:39:52 +0200 Subject: [PATCH] [Wav2Vec2 - MMS] Correct directly loading adapters weights (#24335) * Correct direct lang loading * correct more * revert black * Use tie weights instead= * add tests * add tests * make style --- .../models/hubert/modeling_hubert.py | 26 ++++++++++--- src/transformers/models/sew/modeling_sew.py | 26 ++++++++++--- .../models/sew_d/modeling_sew_d.py | 26 ++++++++++--- .../models/unispeech/modeling_unispeech.py | 26 ++++++++++--- .../unispeech_sat/modeling_unispeech_sat.py | 26 ++++++++++--- .../models/wav2vec2/modeling_wav2vec2.py | 37 ++++++++++++++++--- .../modeling_wav2vec2_conformer.py | 9 +---- .../models/wavlm/modeling_wavlm.py | 26 ++++++++++--- .../models/wav2vec2/test_modeling_wav2vec2.py | 34 +++++++++++++++++ 9 files changed, 193 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index c7436c5ea4..70a8c07940 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -1134,12 +1134,14 @@ class HubertModel(HubertPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT class HubertForCTC(HubertPreTrainedModel): - def __init__(self, config, target_lang=None): + def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) self.hubert = HubertModel(config) self.dropout = nn.Dropout(config.final_dropout) + self.target_lang = target_lang + if config.vocab_size is None: raise ValueError( f"You are trying to instantiate {self.__class__} with a configuration that " @@ -1152,15 +1154,29 @@ class HubertForCTC(HubertPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + # Initialize weights and apply final 
processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. + """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for Hubert so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. + target_lang = self.target_lang + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: logger.info("By default `target_lang` is set to 'eng'.") elif target_lang is not None: - self.load_adapter(target_lang) - - # Initialize weights and apply final processing - self.post_init() + self.load_adapter(target_lang, force_load=True) def freeze_feature_extractor(self): """ diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 928b2fd3ad..dd854c49f5 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -969,12 +969,14 @@ class SEWModel(SEWPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW class SEWForCTC(SEWPreTrainedModel): - def __init__(self, config, target_lang=None): + def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) self.sew = SEWModel(config) self.dropout = nn.Dropout(config.final_dropout) + self.target_lang 
= target_lang + if config.vocab_size is None: raise ValueError( f"You are trying to instantiate {self.__class__} with a configuration that " @@ -987,15 +989,29 @@ class SEWForCTC(SEWPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. + """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for SEW so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, SEW never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. 
+ target_lang = self.target_lang + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: logger.info("By default `target_lang` is set to 'eng'.") elif target_lang is not None: - self.load_adapter(target_lang) - - # Initialize weights and apply final processing - self.post_init() + self.load_adapter(target_lang, force_load=True) def freeze_feature_extractor(self): """ diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 403c865cdd..7f7c1977d6 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1509,12 +1509,14 @@ class SEWDModel(SEWDPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD class SEWDForCTC(SEWDPreTrainedModel): - def __init__(self, config, target_lang=None): + def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) self.sew_d = SEWDModel(config) self.dropout = nn.Dropout(config.final_dropout) + self.target_lang = target_lang + if config.vocab_size is None: raise ValueError( f"You are trying to instantiate {self.__class__} with a configuration that " @@ -1527,15 +1529,29 @@ class SEWDForCTC(SEWDPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. 
+ """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for SEWD so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, SEWD never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. + target_lang = self.target_lang + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: logger.info("By default `target_lang` is set to 'eng'.") elif target_lang is not None: - self.load_adapter(target_lang) - - # Initialize weights and apply final processing - self.post_init() + self.load_adapter(target_lang, force_load=True) def freeze_feature_extractor(self): """ diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index da81bc4b14..e068fa59e5 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -1378,12 +1378,14 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH class UniSpeechForCTC(UniSpeechPreTrainedModel): - def __init__(self, config, target_lang=None): + def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) self.unispeech = UniSpeechModel(config) self.dropout = nn.Dropout(config.final_dropout) + self.target_lang = target_lang + if config.vocab_size is None: raise ValueError( f"You are trying to instantiate {self.__class__} with a configuration that " @@ -1396,15 +1398,29 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel): ) 
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. + """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. + target_lang = self.target_lang + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: logger.info("By default `target_lang` is set to 'eng'.") elif target_lang is not None: - self.load_adapter(target_lang) - - # Initialize weights and apply final processing - self.post_init() + self.load_adapter(target_lang, force_load=True) def freeze_feature_extractor(self): """ diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 58ba244ade..2ed8a5d572 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -1385,12 +1385,14 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT class 
UniSpeechSatForCTC(UniSpeechSatPreTrainedModel): - def __init__(self, config, target_lang=None): + def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) self.unispeech_sat = UniSpeechSatModel(config) self.dropout = nn.Dropout(config.final_dropout) + self.target_lang = target_lang + if config.vocab_size is None: raise ValueError( f"You are trying to instantiate {self.__class__} with a configuration that " @@ -1403,15 +1405,29 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. + """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for UniSpeechSat so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, UniSpeechSat never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. 
+ target_lang = self.target_lang + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: logger.info("By default `target_lang` is set to 'eng'.") elif target_lang is not None: - self.load_adapter(target_lang) - - # Initialize weights and apply final processing - self.post_init() + self.load_adapter(target_lang, force_load=True) def freeze_feature_extractor(self): """ diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 26cfe3c6cb..43ab2408bb 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1207,7 +1207,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): if isinstance(self, Wav2Vec2ForCTC): self._init_weights(self.lm_head) - def load_adapter(self, target_lang: str, **kwargs): + def load_adapter(self, target_lang: str, force_load=True, **kwargs): r""" Load a language adapter model from a pre-trained adapter model. @@ -1215,6 +1215,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): target_lang (`str`): Has to be a language id of an existing adapter weight. Adapter weights are stored in the format adapter.<lang>.safetensors or adapter.<lang>.bin + force_load (`bool`, defaults to `True`): + Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. 
@@ -1271,6 +1273,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): if self.config.adapter_attn_dim is None: raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.") + if target_lang == self.target_lang and not force_load: + logger.warn(f"Adapter weights are already set to {target_lang}.") + return + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -1372,6 +1378,9 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()} self.load_state_dict(state_dict, strict=False) + # set target language correctly + self.target_lang = target_lang + WAV_2_VEC_2_START_DOCSTRING = r""" Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech @@ -1854,12 +1863,14 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): WAV_2_VEC_2_START_DOCSTRING, ) class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): - def __init__(self, config, target_lang=None): + def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) self.wav2vec2 = Wav2Vec2Model(config) self.dropout = nn.Dropout(config.final_dropout) + self.target_lang = target_lang + if config.vocab_size is None: raise ValueError( f"You are trying to instantiate {self.__class__} with a configuration that " @@ -1872,15 +1883,29 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. 
+ """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. + target_lang = self.target_lang + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: logger.info("By default `target_lang` is set to 'eng'.") elif target_lang is not None: - self.load_adapter(target_lang) - - # Initialize weights and apply final processing - self.post_init() + self.load_adapter(target_lang, force_load=True) def freeze_feature_extractor(self): """ diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 4eba3023de..3e37a4a504 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -1610,6 +1610,8 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) self.dropout = nn.Dropout(config.final_dropout) + self.target_lang = target_lang + if config.vocab_size is None: raise ValueError( f"You are trying to instantiate {self.__class__} with a configuration that " @@ -1622,13 +1624,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) - if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: - raise ValueError(f"Cannot pass `target_lang`: {target_lang} if 
`config.adapter_attn_dim` is not defined.") - elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: - logger.info("By default `target_lang` is set to 'eng'.") - elif target_lang is not None: - self.load_adapter(target_lang) - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 4181794ce1..e4072d9372 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -1272,12 +1272,14 @@ class WavLMModel(WavLMPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM class WavLMForCTC(WavLMPreTrainedModel): - def __init__(self, config, target_lang=None): + def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) self.wavlm = WavLMModel(config) self.dropout = nn.Dropout(config.final_dropout) + self.target_lang = target_lang + if config.vocab_size is None: raise ValueError( f"You are trying to instantiate {self.__class__} with a configuration that " @@ -1290,15 +1292,29 @@ class WavLMForCTC(WavLMPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. + """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for WavLM so that we do not have to introduce a new API to + # [`PreTrainedModel`]. 
While slightly hacky, WavLM never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. + target_lang = self.target_lang + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: logger.info("By default `target_lang` is set to 'eng'.") elif target_lang is not None: - self.load_adapter(target_lang) - - # Initialize weights and apply final processing - self.post_init() + self.load_adapter(target_lang, force_load=True) def freeze_feature_extractor(self): """ diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index 87206a4b9b..65bfcb4451 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -1117,6 +1117,40 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): def test_feed_forward_chunking(self): pass + def test_load_and_set_attn_adapter(self): + processor = Wav2Vec2Processor.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True + ) + + def get_logits(model, input_features): + model = model.to(torch_device) + batch = processor( + input_features, + padding=True, + sampling_rate=processor.feature_extractor.sampling_rate, + return_tensors="pt", + ) + + with torch.no_grad(): + logits = model( + input_values=batch["input_values"].to(torch_device), + attention_mask=batch["attention_mask"].to(torch_device), + ).logits + return logits + + input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]] + + model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter", target_lang="it") + + logits = get_logits(model, input_features) + + model_2 = 
Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter") + model_2.load_adapter("it") + + logits_2 = get_logits(model_2, input_features) + + self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3)) + def test_load_attn_adapter(self): processor = Wav2Vec2Processor.from_pretrained( "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True