[Wav2Vec2 - MMS] Correct directly loading adapters weights (#24335)
* Correct direct lang loading * correct more * revert black * Use tie weights instead= * add tests * add tests * make style
This commit is contained in:
committed by
GitHub
parent
e5c760d636
commit
b0513b013b
@@ -1134,12 +1134,14 @@ class HubertModel(HubertPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
|
||||
class HubertForCTC(HubertPreTrainedModel):
|
||||
def __init__(self, config, target_lang=None):
|
||||
def __init__(self, config, target_lang: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.hubert = HubertModel(config)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
|
||||
self.target_lang = target_lang
|
||||
|
||||
if config.vocab_size is None:
|
||||
raise ValueError(
|
||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||
@@ -1152,15 +1154,29 @@ class HubertForCTC(HubertPreTrainedModel):
|
||||
)
|
||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||
passing `target_lang=...` to `from_pretrained(...)`.
|
||||
|
||||
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||
"""
|
||||
|
||||
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||
# correctly load adapter layers for Hubert so that we do not have to introduce a new API to
|
||||
# [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is
|
||||
# ok to repurpose this function here.
|
||||
target_lang = self.target_lang
|
||||
|
||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||
logger.info("By default `target_lang` is set to 'eng'.")
|
||||
elif target_lang is not None:
|
||||
self.load_adapter(target_lang)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
self.load_adapter(target_lang, force_load=True)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
|
||||
@@ -969,12 +969,14 @@ class SEWModel(SEWPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
|
||||
class SEWForCTC(SEWPreTrainedModel):
|
||||
def __init__(self, config, target_lang=None):
|
||||
def __init__(self, config, target_lang: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.sew = SEWModel(config)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
|
||||
self.target_lang = target_lang
|
||||
|
||||
if config.vocab_size is None:
|
||||
raise ValueError(
|
||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||
@@ -987,15 +989,29 @@ class SEWForCTC(SEWPreTrainedModel):
|
||||
)
|
||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||
passing `target_lang=...` to `from_pretrained(...)`.
|
||||
|
||||
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||
"""
|
||||
|
||||
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||
# correctly load adapter layers for SEW so that we do not have to introduce a new API to
|
||||
# [`PreTrainedModel`]. While slightly hacky, SEW never has to tie input and output embeddings, so that it is
|
||||
# ok to repurpose this function here.
|
||||
target_lang = self.target_lang
|
||||
|
||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||
logger.info("By default `target_lang` is set to 'eng'.")
|
||||
elif target_lang is not None:
|
||||
self.load_adapter(target_lang)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
self.load_adapter(target_lang, force_load=True)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
|
||||
@@ -1509,12 +1509,14 @@ class SEWDModel(SEWDPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
|
||||
class SEWDForCTC(SEWDPreTrainedModel):
|
||||
def __init__(self, config, target_lang=None):
|
||||
def __init__(self, config, target_lang: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.sew_d = SEWDModel(config)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
|
||||
self.target_lang = target_lang
|
||||
|
||||
if config.vocab_size is None:
|
||||
raise ValueError(
|
||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||
@@ -1527,15 +1529,29 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
||||
)
|
||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||
passing `target_lang=...` to `from_pretrained(...)`.
|
||||
|
||||
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||
"""
|
||||
|
||||
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||
# correctly load adapter layers for SEWD so that we do not have to introduce a new API to
|
||||
# [`PreTrainedModel`]. While slightly hacky, SEWD never has to tie input and output embeddings, so that it is
|
||||
# ok to repurpose this function here.
|
||||
target_lang = self.target_lang
|
||||
|
||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||
logger.info("By default `target_lang` is set to 'eng'.")
|
||||
elif target_lang is not None:
|
||||
self.load_adapter(target_lang)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
self.load_adapter(target_lang, force_load=True)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
|
||||
@@ -1378,12 +1378,14 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
|
||||
class UniSpeechForCTC(UniSpeechPreTrainedModel):
|
||||
def __init__(self, config, target_lang=None):
|
||||
def __init__(self, config, target_lang: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.unispeech = UniSpeechModel(config)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
|
||||
self.target_lang = target_lang
|
||||
|
||||
if config.vocab_size is None:
|
||||
raise ValueError(
|
||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||
@@ -1396,15 +1398,29 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
|
||||
)
|
||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||
passing `target_lang=...` to `from_pretrained(...)`.
|
||||
|
||||
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||
"""
|
||||
|
||||
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||
# correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to
|
||||
# [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is
|
||||
# ok to repurpose this function here.
|
||||
target_lang = self.target_lang
|
||||
|
||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||
logger.info("By default `target_lang` is set to 'eng'.")
|
||||
elif target_lang is not None:
|
||||
self.load_adapter(target_lang)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
self.load_adapter(target_lang, force_load=True)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
|
||||
@@ -1385,12 +1385,14 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
|
||||
class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
|
||||
def __init__(self, config, target_lang=None):
|
||||
def __init__(self, config, target_lang: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.unispeech_sat = UniSpeechSatModel(config)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
|
||||
self.target_lang = target_lang
|
||||
|
||||
if config.vocab_size is None:
|
||||
raise ValueError(
|
||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||
@@ -1403,15 +1405,29 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
|
||||
)
|
||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||
passing `target_lang=...` to `from_pretrained(...)`.
|
||||
|
||||
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||
"""
|
||||
|
||||
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||
# correctly load adapter layers for UniSpeechSat so that we do not have to introduce a new API to
|
||||
# [`PreTrainedModel`]. While slightly hacky, UniSpeechSat never has to tie input and output embeddings, so that it is
|
||||
# ok to repurpose this function here.
|
||||
target_lang = self.target_lang
|
||||
|
||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||
logger.info("By default `target_lang` is set to 'eng'.")
|
||||
elif target_lang is not None:
|
||||
self.load_adapter(target_lang)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
self.load_adapter(target_lang, force_load=True)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
|
||||
@@ -1207,7 +1207,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
if isinstance(self, Wav2Vec2ForCTC):
|
||||
self._init_weights(self.lm_head)
|
||||
|
||||
def load_adapter(self, target_lang: str, **kwargs):
|
||||
def load_adapter(self, target_lang: str, force_load=True, **kwargs):
|
||||
r"""
|
||||
Load a language adapter model from a pre-trained adapter model.
|
||||
|
||||
@@ -1215,6 +1215,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
target_lang (`str`):
|
||||
Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
|
||||
adapter.<lang>.safetensors or adapter.<lang>.bin
|
||||
force_load (`bool`, defaults to `True`):
|
||||
Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
@@ -1271,6 +1273,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
if self.config.adapter_attn_dim is None:
|
||||
raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
|
||||
if target_lang == self.target_lang and not force_load:
|
||||
logger.warn(f"Adapter weights are already set to {target_lang}.")
|
||||
return
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
@@ -1372,6 +1378,9 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# set target language corectly
|
||||
self.target_lang = target_lang
|
||||
|
||||
|
||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
||||
@@ -1854,12 +1863,14 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
WAV_2_VEC_2_START_DOCSTRING,
|
||||
)
|
||||
class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
def __init__(self, config, target_lang=None):
|
||||
def __init__(self, config, target_lang: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
|
||||
self.target_lang = target_lang
|
||||
|
||||
if config.vocab_size is None:
|
||||
raise ValueError(
|
||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||
@@ -1872,15 +1883,29 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
)
|
||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||
passing `target_lang=...` to `from_pretrained(...)`.
|
||||
|
||||
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||
"""
|
||||
|
||||
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||
# correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to
|
||||
# [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is
|
||||
# ok to repurpose this function here.
|
||||
target_lang = self.target_lang
|
||||
|
||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||
logger.info("By default `target_lang` is set to 'eng'.")
|
||||
elif target_lang is not None:
|
||||
self.load_adapter(target_lang)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
self.load_adapter(target_lang, force_load=True)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
|
||||
@@ -1610,6 +1610,8 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
|
||||
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
|
||||
self.target_lang = target_lang
|
||||
|
||||
if config.vocab_size is None:
|
||||
raise ValueError(
|
||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||
@@ -1622,13 +1624,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
|
||||
)
|
||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||
|
||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||
logger.info("By default `target_lang` is set to 'eng'.")
|
||||
elif target_lang is not None:
|
||||
self.load_adapter(target_lang)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
||||
@@ -1272,12 +1272,14 @@ class WavLMModel(WavLMPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
|
||||
class WavLMForCTC(WavLMPreTrainedModel):
|
||||
def __init__(self, config, target_lang=None):
|
||||
def __init__(self, config, target_lang: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.wavlm = WavLMModel(config)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
|
||||
self.target_lang = target_lang
|
||||
|
||||
if config.vocab_size is None:
|
||||
raise ValueError(
|
||||
f"You are trying to instantiate {self.__class__} with a configuration that "
|
||||
@@ -1290,15 +1292,29 @@ class WavLMForCTC(WavLMPreTrainedModel):
|
||||
)
|
||||
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
|
||||
passing `target_lang=...` to `from_pretrained(...)`.
|
||||
|
||||
This method is **not** supposed to be called by the user and is prone to be changed in the future.
|
||||
"""
|
||||
|
||||
# Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
|
||||
# correctly load adapter layers for WavLM so that we do not have to introduce a new API to
|
||||
# [`PreTrainedModel`]. While slightly hacky, WavLM never has to tie input and output embeddings, so that it is
|
||||
# ok to repurpose this function here.
|
||||
target_lang = self.target_lang
|
||||
|
||||
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
|
||||
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
|
||||
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
|
||||
logger.info("By default `target_lang` is set to 'eng'.")
|
||||
elif target_lang is not None:
|
||||
self.load_adapter(target_lang)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
self.load_adapter(target_lang, force_load=True)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
|
||||
@@ -1117,6 +1117,40 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_load_and_set_attn_adapter(self):
|
||||
processor = Wav2Vec2Processor.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||
)
|
||||
|
||||
def get_logits(model, input_features):
|
||||
model = model.to(torch_device)
|
||||
batch = processor(
|
||||
input_features,
|
||||
padding=True,
|
||||
sampling_rate=processor.feature_extractor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(
|
||||
input_values=batch["input_values"].to(torch_device),
|
||||
attention_mask=batch["attention_mask"].to(torch_device),
|
||||
).logits
|
||||
return logits
|
||||
|
||||
input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter", target_lang="it")
|
||||
|
||||
logits = get_logits(model, input_features)
|
||||
|
||||
model_2 = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
|
||||
model_2.load_adapter("it")
|
||||
|
||||
logits_2 = get_logits(model_2, input_features)
|
||||
|
||||
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||
|
||||
def test_load_attn_adapter(self):
|
||||
processor = Wav2Vec2Processor.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||
|
||||
Reference in New Issue
Block a user