Introduce modular files for speech models (#35902)
* WAV_2_VEC_2 to WAV2VEC2 * added modular files for hubert, wavlm, wav2vec2_bert, data2vec_audio * remove unnessary definitions in modulars * added modular files for UniSpeech, UniSpeechSat, Wav2Vec2Conformer * docstring fix for UniSpeechForCTC * removed unneccessary re-definition of modular classes * reverted lazy imports change on modular_model_converter, type-alias for Wav2Vec2BaseModelOutput * top-level import of deepspeed in seamless_m4t, speecht5 * avoid tracking imports inside classes, relocate lazy deepspeed, peft imports in their original locations * convert modular * tiny modular typing fixes * some more modular fixes * make style --------- Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com> Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
This commit is contained in:
@@ -79,8 +79,11 @@ def preserve_case_replace(text, patterns: dict, default_name: str):
|
||||
|
||||
def get_cased_name(lowercase_name: str) -> str:
|
||||
"""From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
|
||||
alt_lowercase_name = lowercase_name.replace("_", "-")
|
||||
if lowercase_name in CONFIG_MAPPING_NAMES:
|
||||
return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
|
||||
elif alt_lowercase_name in CONFIG_MAPPING_NAMES:
|
||||
return CONFIG_MAPPING_NAMES[alt_lowercase_name].replace("Config", "")
|
||||
else:
|
||||
return "".join(x.title() for x in lowercase_name.split("_"))
|
||||
|
||||
@@ -106,6 +109,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
|
||||
def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
|
||||
super().__init__()
|
||||
old_name = old_name.replace("-", "_")
|
||||
new_name = new_name.replace("-", "_")
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
self.cased_new_name = get_cased_name(self.new_name)
|
||||
@@ -535,7 +540,7 @@ def find_all_dependencies(
|
||||
|
||||
|
||||
# Top-level variables that match the following patterns will always use the value in the `modular_xxx.py` file
|
||||
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]
|
||||
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]
|
||||
|
||||
# Top-level variables that match the following patterns will use the value in the `modular_xxx.py` file only if they are not None
|
||||
ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"]
|
||||
@@ -616,6 +621,7 @@ class ModuleMapper(CSTVisitor, ABC):
|
||||
self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
|
||||
self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes
|
||||
self.current_function = None # this keeps track of the current module-scope function
|
||||
self.current_class = None # this keeps track of the current module-scope class
|
||||
self.current_assignment = None # this keeps track of the current module-scope assignment
|
||||
# this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency
|
||||
self.objects_imported_from_modeling = set()
|
||||
@@ -672,7 +678,7 @@ class ModuleMapper(CSTVisitor, ABC):
|
||||
|
||||
def visit_If(self, node):
|
||||
# If we are inside a function, do not add the import to the list of imports
|
||||
if self.current_function is None:
|
||||
if self.current_function is None and self.current_class is None:
|
||||
for stmt in node.body.body:
|
||||
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
|
||||
self.imports.append(node)
|
||||
@@ -680,6 +686,10 @@ class ModuleMapper(CSTVisitor, ABC):
|
||||
def visit_ClassDef(self, node: ClassDef) -> None:
|
||||
"""Record class nodes to create their dependencies at the end."""
|
||||
self.classes[node.name.value] = node
|
||||
self.current_class = node.name.value
|
||||
|
||||
def leave_ClassDef(self, node):
|
||||
self.current_class = None
|
||||
|
||||
def visit_Name(self, node: cst.Call):
|
||||
"""This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
|
||||
@@ -1024,11 +1034,20 @@ def replace_class_node(
|
||||
new_decorators = (
|
||||
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
|
||||
)
|
||||
|
||||
# Keep return annotation in `modular_xxx.py` if any, else original return annotation
|
||||
new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns
|
||||
|
||||
if not re.match(
|
||||
r"\ndef .*\(.*\):\n raise.*Error\(.*",
|
||||
mapper.python_module.code_for_node(updated_methods[name]),
|
||||
):
|
||||
func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators)
|
||||
func = func.with_changes(
|
||||
body=updated_methods[name].body,
|
||||
params=new_params,
|
||||
decorators=new_decorators,
|
||||
returns=new_return_annotation,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -1136,7 +1155,7 @@ def append_new_import_node(
|
||||
import_node = node.body[0]
|
||||
names_to_keep = []
|
||||
for name in import_node.names:
|
||||
name_value = name.evaluated_name
|
||||
name_value = name.evaluated_alias or name.evaluated_name
|
||||
if name_value not in unused_imports and name_value not in added_names:
|
||||
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
|
||||
added_names.add(name_value)
|
||||
|
||||
Reference in New Issue
Block a user