[modular] fixes! (#33820)
* fix converter for function definitions * small changes * no prints * style
This commit is contained in:
@@ -537,7 +537,7 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
"feature_extractor": {},
|
||||
}
|
||||
self.match_patterns = "|".join(self.files.keys())
|
||||
self.all_functions = {}
|
||||
self.all_definitions = {}
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
"""When visiting imports from `transformers.models.xxx` we need to:
|
||||
@@ -647,6 +647,7 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
node = class_finder.global_nodes.get(dependency, None)
|
||||
if node is not None:
|
||||
if dependency not in file_to_update:
|
||||
node = self.all_definitions.get(dependency, node)
|
||||
start_insert_idx -= 1
|
||||
file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
|
||||
elif dependency not in self.inserted_deps:
|
||||
@@ -683,6 +684,12 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||
return updated_node
|
||||
|
||||
def leave_FunctionDef(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
self.all_definitions[node.name.value] = node
|
||||
return node
|
||||
|
||||
def leave_If(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
@@ -757,7 +764,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--files_to_parse",
|
||||
default=["all"],
|
||||
default=["examples/modular-transformers/modular_dummy.py"],
|
||||
nargs="+",
|
||||
help="A list of `modular_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user