Introduce modular files for speech models (#35902)

* WAV_2_VEC_2 to WAV2VEC2

* added modular files for hubert, wavlm, wav2vec2_bert, data2vec_audio

* remove unnessary definitions in modulars

* added modular files for UniSpeech, UniSpeechSat, Wav2Vec2Conformer

* docstring fix for UniSpeechForCTC

* removed unneccessary re-definition of modular classes

* reverted lazy imports change on modular_model_converter, type-alias for Wav2Vec2BaseModelOutput

* top-level import of deepspeed in seamless_m4t, speecht5

* avoid tracking imports inside classes, relocate lazy deepspeed, peft imports in their original locations

* convert modular

* tiny modular typing fixes

* some more modular fixes

* make style

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
This commit is contained in:
Nikos Antoniou
2025-04-04 12:46:27 +03:00
committed by GitHub
parent d130cd0e16
commit f74d7da836
43 changed files with 6690 additions and 2216 deletions

View File

@@ -79,8 +79,11 @@ def preserve_case_replace(text, patterns: dict, default_name: str):
def get_cased_name(lowercase_name: str) -> str:
"""From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
alt_lowercase_name = lowercase_name.replace("_", "-")
if lowercase_name in CONFIG_MAPPING_NAMES:
return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
elif alt_lowercase_name in CONFIG_MAPPING_NAMES:
return CONFIG_MAPPING_NAMES[alt_lowercase_name].replace("Config", "")
else:
return "".join(x.title() for x in lowercase_name.split("_"))
@@ -106,6 +109,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
super().__init__()
old_name = old_name.replace("-", "_")
new_name = new_name.replace("-", "_")
self.old_name = old_name
self.new_name = new_name
self.cased_new_name = get_cased_name(self.new_name)
@@ -535,7 +540,7 @@ def find_all_dependencies(
# Top-level variables that match the following patterns will always use the value in the `modular_xxx.py` file
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]
# Top-level variables that match the following patterns will use the value in the `modular_xxx.py` file only if they are not None
ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"]
@@ -616,6 +621,7 @@ class ModuleMapper(CSTVisitor, ABC):
self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes
self.current_function = None # this keeps track of the current module-scope function
self.current_class = None # this keeps track of the current module-scope class
self.current_assignment = None # this keeps track of the current module-scope assignment
# this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency
self.objects_imported_from_modeling = set()
@@ -672,7 +678,7 @@ class ModuleMapper(CSTVisitor, ABC):
def visit_If(self, node):
# If we are inside a function, do not add the import to the list of imports
if self.current_function is None:
if self.current_function is None and self.current_class is None:
for stmt in node.body.body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
self.imports.append(node)
@@ -680,6 +686,10 @@ class ModuleMapper(CSTVisitor, ABC):
def visit_ClassDef(self, node: ClassDef) -> None:
"""Record class nodes to create their dependencies at the end."""
self.classes[node.name.value] = node
self.current_class = node.name.value
def leave_ClassDef(self, node):
self.current_class = None
def visit_Name(self, node: cst.Call):
"""This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
@@ -1024,11 +1034,20 @@ def replace_class_node(
new_decorators = (
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
)
# Keep return annotation in `modular_xxx.py` if any, else original return annotation
new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns
if not re.match(
r"\ndef .*\(.*\):\n raise.*Error\(.*",
mapper.python_module.code_for_node(updated_methods[name]),
):
func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators)
func = func.with_changes(
body=updated_methods[name].body,
params=new_params,
decorators=new_decorators,
returns=new_return_annotation,
)
else:
continue
@@ -1136,7 +1155,7 @@ def append_new_import_node(
import_node = node.body[0]
names_to_keep = []
for name in import_node.names:
name_value = name.evaluated_name
name_value = name.evaluated_alias or name.evaluated_name
if name_value not in unused_imports and name_value not in added_names:
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
added_names.add(name_value)