[XLSR-Wav2Vec2] Add multi-lingual Wav2Vec2 models (#10648)
* add conversion script * add wav2vec2 xslr models * finish * Update docs/source/model_doc/xlsr_wav2vec2.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
63c295ac05
commit
602d63f05c
@@ -90,7 +90,8 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
||||
else:
|
||||
for key, mapped_key in MAPPING.items():
|
||||
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
|
||||
if key in name:
|
||||
|
||||
if key in name or (key.split("w2v_model.")[-1] == name.split(".")[0] and not is_finetuned):
|
||||
is_used = True
|
||||
if "*" in mapped_key:
|
||||
layer_index = name.split(key)[0].split(".")[-2]
|
||||
@@ -110,7 +111,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
||||
if not is_used:
|
||||
unused_weights.append(name)
|
||||
|
||||
logger.info("Unused weights", unused_weights)
|
||||
logger.warn(f"Unused weights: {unused_weights}")
|
||||
|
||||
|
||||
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
|
||||
|
||||
Reference in New Issue
Block a user