[modular] fixes! (#33820)

* fix converter for function definitions

* small changes

* no prints

* style
This commit is contained in:
Arthur
2024-09-30 16:43:55 +02:00
committed by GitHub
parent 1d29a75a6a
commit 1dba608df9
9 changed files with 322 additions and 247 deletions

View File

@@ -537,7 +537,7 @@ class ModularConverterTransformer(CSTTransformer):
"feature_extractor": {},
}
self.match_patterns = "|".join(self.files.keys())
self.all_functions = {}
self.all_definitions = {}
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""When visiting imports from `transformers.models.xxx` we need to:
@@ -647,6 +647,7 @@ class ModularConverterTransformer(CSTTransformer):
node = class_finder.global_nodes.get(dependency, None)
if node is not None:
if dependency not in file_to_update:
node = self.all_definitions.get(dependency, node)
start_insert_idx -= 1
file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
elif dependency not in self.inserted_deps:
@@ -683,6 +684,12 @@ class ModularConverterTransformer(CSTTransformer):
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.all_definitions[node.name.value] = node
return node
def leave_If(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
@@ -757,7 +764,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["all"],
default=["examples/modular-transformers/modular_dummy.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)