[Wav2Vec2 - MMS] Correct directly loading adapters weights (#24335)
* Correct direct lang loading * correct more * revert black * Use tie weights instead * add tests * add tests * make style
This commit is contained in:
committed by
GitHub
parent
e5c760d636
commit
b0513b013b
@@ -1134,12 +1134,14 @@ class HubertModel(HubertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
|
||||||
class HubertForCTC(HubertPreTrainedModel):
|
class HubertForCTC(HubertPreTrainedModel):
|
||||||
def __init__(self, config, target_lang=None):
|
def __init__(self, config, target_lang: Optional[str] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.hubert = HubertModel(config)
|
self.hubert = HubertModel(config)
|
||||||
self.dropout = nn.Dropout(config.final_dropout)
|
self.dropout = nn.Dropout(config.final_dropout)
|
||||||
|
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
if config.vocab_size is None:
|
if config.vocab_size is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||||
@@ -1152,15 +1154,29 @@ class HubertForCTC(HubertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""
|
||||||
|
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||||
|
passing `target_lang=...` to `from_pretrained(...)`.
|
||||||
|
|
||||||
|
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||||
|
# correctly load adapter layers for Hubert so that we do not have to introduce a new API to
|
||||||
|
# [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is
|
||||||
|
# ok to repurpose this function here.
|
||||||
|
target_lang = self.target_lang
|
||||||
|
|
||||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||||
logger.info("By default `target_lang` is set to 'eng'.")
|
logger.info("By default `target_lang` is set to 'eng'.")
|
||||||
elif target_lang is not None:
|
elif target_lang is not None:
|
||||||
self.load_adapter(target_lang)
|
self.load_adapter(target_lang, force_load=True)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def freeze_feature_extractor(self):
|
def freeze_feature_extractor(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -969,12 +969,14 @@ class SEWModel(SEWPreTrainedModel):
|
|||||||
)
|
)
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
|
||||||
class SEWForCTC(SEWPreTrainedModel):
|
class SEWForCTC(SEWPreTrainedModel):
|
||||||
def __init__(self, config, target_lang=None):
|
def __init__(self, config, target_lang: Optional[str] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.sew = SEWModel(config)
|
self.sew = SEWModel(config)
|
||||||
self.dropout = nn.Dropout(config.final_dropout)
|
self.dropout = nn.Dropout(config.final_dropout)
|
||||||
|
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
if config.vocab_size is None:
|
if config.vocab_size is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||||
@@ -987,15 +989,29 @@ class SEWForCTC(SEWPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""
|
||||||
|
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||||
|
passing `target_lang=...` to `from_pretrained(...)`.
|
||||||
|
|
||||||
|
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||||
|
# correctly load adapter layers for SEW so that we do not have to introduce a new API to
|
||||||
|
# [`PreTrainedModel`]. While slightly hacky, SEW never has to tie input and output embeddings, so that it is
|
||||||
|
# ok to repurpose this function here.
|
||||||
|
target_lang = self.target_lang
|
||||||
|
|
||||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||||
logger.info("By default `target_lang` is set to 'eng'.")
|
logger.info("By default `target_lang` is set to 'eng'.")
|
||||||
elif target_lang is not None:
|
elif target_lang is not None:
|
||||||
self.load_adapter(target_lang)
|
self.load_adapter(target_lang, force_load=True)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def freeze_feature_extractor(self):
|
def freeze_feature_extractor(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1509,12 +1509,14 @@ class SEWDModel(SEWDPreTrainedModel):
|
|||||||
)
|
)
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
|
||||||
class SEWDForCTC(SEWDPreTrainedModel):
|
class SEWDForCTC(SEWDPreTrainedModel):
|
||||||
def __init__(self, config, target_lang=None):
|
def __init__(self, config, target_lang: Optional[str] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.sew_d = SEWDModel(config)
|
self.sew_d = SEWDModel(config)
|
||||||
self.dropout = nn.Dropout(config.final_dropout)
|
self.dropout = nn.Dropout(config.final_dropout)
|
||||||
|
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
if config.vocab_size is None:
|
if config.vocab_size is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||||
@@ -1527,15 +1529,29 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""
|
||||||
|
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||||
|
passing `target_lang=...` to `from_pretrained(...)`.
|
||||||
|
|
||||||
|
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||||
|
# correctly load adapter layers for SEWD so that we do not have to introduce a new API to
|
||||||
|
# [`PreTrainedModel`]. While slightly hacky, SEWD never has to tie input and output embeddings, so that it is
|
||||||
|
# ok to repurpose this function here.
|
||||||
|
target_lang = self.target_lang
|
||||||
|
|
||||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||||
logger.info("By default `target_lang` is set to 'eng'.")
|
logger.info("By default `target_lang` is set to 'eng'.")
|
||||||
elif target_lang is not None:
|
elif target_lang is not None:
|
||||||
self.load_adapter(target_lang)
|
self.load_adapter(target_lang, force_load=True)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def freeze_feature_extractor(self):
|
def freeze_feature_extractor(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1378,12 +1378,14 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
|
|||||||
)
|
)
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
|
||||||
class UniSpeechForCTC(UniSpeechPreTrainedModel):
|
class UniSpeechForCTC(UniSpeechPreTrainedModel):
|
||||||
def __init__(self, config, target_lang=None):
|
def __init__(self, config, target_lang: Optional[str] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.unispeech = UniSpeechModel(config)
|
self.unispeech = UniSpeechModel(config)
|
||||||
self.dropout = nn.Dropout(config.final_dropout)
|
self.dropout = nn.Dropout(config.final_dropout)
|
||||||
|
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
if config.vocab_size is None:
|
if config.vocab_size is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||||
@@ -1396,15 +1398,29 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""
|
||||||
|
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||||
|
passing `target_lang=...` to `from_pretrained(...)`.
|
||||||
|
|
||||||
|
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||||
|
# correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to
|
||||||
|
# [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is
|
||||||
|
# ok to repurpose this function here.
|
||||||
|
target_lang = self.target_lang
|
||||||
|
|
||||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||||
logger.info("By default `target_lang` is set to 'eng'.")
|
logger.info("By default `target_lang` is set to 'eng'.")
|
||||||
elif target_lang is not None:
|
elif target_lang is not None:
|
||||||
self.load_adapter(target_lang)
|
self.load_adapter(target_lang, force_load=True)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def freeze_feature_extractor(self):
|
def freeze_feature_extractor(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1385,12 +1385,14 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
|
|||||||
)
|
)
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
|
||||||
class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
|
class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
|
||||||
def __init__(self, config, target_lang=None):
|
def __init__(self, config, target_lang: Optional[str] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.unispeech_sat = UniSpeechSatModel(config)
|
self.unispeech_sat = UniSpeechSatModel(config)
|
||||||
self.dropout = nn.Dropout(config.final_dropout)
|
self.dropout = nn.Dropout(config.final_dropout)
|
||||||
|
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
if config.vocab_size is None:
|
if config.vocab_size is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||||
@@ -1403,15 +1405,29 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""
|
||||||
|
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||||
|
passing `target_lang=...` to `from_pretrained(...)`.
|
||||||
|
|
||||||
|
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||||
|
# correctly load adapter layers for UniSpeechSat so that we do not have to introduce a new API to
|
||||||
|
# [`PreTrainedModel`]. While slightly hacky, UniSpeechSat never has to tie input and output embeddings, so that it is
|
||||||
|
# ok to repurpose this function here.
|
||||||
|
target_lang = self.target_lang
|
||||||
|
|
||||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||||
logger.info("By default `target_lang` is set to 'eng'.")
|
logger.info("By default `target_lang` is set to 'eng'.")
|
||||||
elif target_lang is not None:
|
elif target_lang is not None:
|
||||||
self.load_adapter(target_lang)
|
self.load_adapter(target_lang, force_load=True)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def freeze_feature_extractor(self):
|
def freeze_feature_extractor(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1207,7 +1207,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
if isinstance(self, Wav2Vec2ForCTC):
|
if isinstance(self, Wav2Vec2ForCTC):
|
||||||
self._init_weights(self.lm_head)
|
self._init_weights(self.lm_head)
|
||||||
|
|
||||||
def load_adapter(self, target_lang: str, **kwargs):
|
def load_adapter(self, target_lang: str, force_load=True, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Load a language adapter model from a pre-trained adapter model.
|
Load a language adapter model from a pre-trained adapter model.
|
||||||
|
|
||||||
@@ -1215,6 +1215,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
target_lang (`str`):
|
target_lang (`str`):
|
||||||
Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
|
Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
|
||||||
adapter.<lang>.safetensors or adapter.<lang>.bin
|
adapter.<lang>.safetensors or adapter.<lang>.bin
|
||||||
|
force_load (`bool`, defaults to `True`):
|
||||||
|
Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
|
||||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||||
standard cache should not be used.
|
standard cache should not be used.
|
||||||
@@ -1271,6 +1273,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
if self.config.adapter_attn_dim is None:
|
if self.config.adapter_attn_dim is None:
|
||||||
raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")
|
raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||||
|
|
||||||
|
if target_lang == self.target_lang and not force_load:
|
||||||
|
logger.warn(f"Adapter weights are already set to {target_lang}.")
|
||||||
|
return
|
||||||
|
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
@@ -1372,6 +1378,9 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
|
state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
|
||||||
self.load_state_dict(state_dict, strict=False)
|
self.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
# set target language correctly
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
|
|
||||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||||
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
||||||
@@ -1854,12 +1863,14 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
|||||||
WAV_2_VEC_2_START_DOCSTRING,
|
WAV_2_VEC_2_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||||
def __init__(self, config, target_lang=None):
|
def __init__(self, config, target_lang: Optional[str] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.wav2vec2 = Wav2Vec2Model(config)
|
self.wav2vec2 = Wav2Vec2Model(config)
|
||||||
self.dropout = nn.Dropout(config.final_dropout)
|
self.dropout = nn.Dropout(config.final_dropout)
|
||||||
|
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
if config.vocab_size is None:
|
if config.vocab_size is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||||
@@ -1872,15 +1883,29 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""
|
||||||
|
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||||
|
passing `target_lang=...` to `from_pretrained(...)`.
|
||||||
|
|
||||||
|
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||||
|
# correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to
|
||||||
|
# [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is
|
||||||
|
# ok to repurpose this function here.
|
||||||
|
target_lang = self.target_lang
|
||||||
|
|
||||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||||
logger.info("By default `target_lang` is set to 'eng'.")
|
logger.info("By default `target_lang` is set to 'eng'.")
|
||||||
elif target_lang is not None:
|
elif target_lang is not None:
|
||||||
self.load_adapter(target_lang)
|
self.load_adapter(target_lang, force_load=True)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def freeze_feature_extractor(self):
|
def freeze_feature_extractor(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1610,6 +1610,8 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
|
|||||||
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
||||||
self.dropout = nn.Dropout(config.final_dropout)
|
self.dropout = nn.Dropout(config.final_dropout)
|
||||||
|
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
if config.vocab_size is None:
|
if config.vocab_size is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||||
@@ -1622,13 +1624,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||||
|
|
||||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
|
||||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
|
||||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
|
||||||
logger.info("By default `target_lang` is set to 'eng'.")
|
|
||||||
elif target_lang is not None:
|
|
||||||
self.load_adapter(target_lang)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
|||||||
@@ -1272,12 +1272,14 @@ class WavLMModel(WavLMPreTrainedModel):
|
|||||||
)
|
)
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
|
||||||
class WavLMForCTC(WavLMPreTrainedModel):
|
class WavLMForCTC(WavLMPreTrainedModel):
|
||||||
def __init__(self, config, target_lang=None):
|
def __init__(self, config, target_lang: Optional[str] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.wavlm = WavLMModel(config)
|
self.wavlm = WavLMModel(config)
|
||||||
self.dropout = nn.Dropout(config.final_dropout)
|
self.dropout = nn.Dropout(config.final_dropout)
|
||||||
|
|
||||||
|
self.target_lang = target_lang
|
||||||
|
|
||||||
if config.vocab_size is None:
|
if config.vocab_size is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||||
@@ -1290,15 +1292,29 @@ class WavLMForCTC(WavLMPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""
|
||||||
|
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||||
|
passing `target_lang=...` to `from_pretrained(...)`.
|
||||||
|
|
||||||
|
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||||
|
# correctly load adapter layers for WavLM so that we do not have to introduce a new API to
|
||||||
|
# [`PreTrainedModel`]. While slightly hacky, WavLM never has to tie input and output embeddings, so that it is
|
||||||
|
# ok to repurpose this function here.
|
||||||
|
target_lang = self.target_lang
|
||||||
|
|
||||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||||
logger.info("By default `target_lang` is set to 'eng'.")
|
logger.info("By default `target_lang` is set to 'eng'.")
|
||||||
elif target_lang is not None:
|
elif target_lang is not None:
|
||||||
self.load_adapter(target_lang)
|
self.load_adapter(target_lang, force_load=True)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def freeze_feature_extractor(self):
|
def freeze_feature_extractor(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1117,6 +1117,40 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_load_and_set_attn_adapter(self):
|
||||||
|
processor = Wav2Vec2Processor.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_logits(model, input_features):
|
||||||
|
model = model.to(torch_device)
|
||||||
|
batch = processor(
|
||||||
|
input_features,
|
||||||
|
padding=True,
|
||||||
|
sampling_rate=processor.feature_extractor.sampling_rate,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(
|
||||||
|
input_values=batch["input_values"].to(torch_device),
|
||||||
|
attention_mask=batch["attention_mask"].to(torch_device),
|
||||||
|
).logits
|
||||||
|
return logits
|
||||||
|
|
||||||
|
input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
|
||||||
|
|
||||||
|
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter", target_lang="it")
|
||||||
|
|
||||||
|
logits = get_logits(model, input_features)
|
||||||
|
|
||||||
|
model_2 = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
|
||||||
|
model_2.load_adapter("it")
|
||||||
|
|
||||||
|
logits_2 = get_logits(model_2, input_features)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||||
|
|
||||||
def test_load_attn_adapter(self):
|
def test_load_attn_adapter(self):
|
||||||
processor = Wav2Vec2Processor.from_pretrained(
|
processor = Wav2Vec2Processor.from_pretrained(
|
||||||
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||||
|
|||||||
Reference in New Issue
Block a user