Modular fix (#34802)

* Modular fix

* style

* remove logger warning

* Update modular_model_converter.py
This commit is contained in:
Cyril Vallez
2024-11-19 16:08:57 +01:00
committed by GitHub
parent ce1d328e3b
commit e3a5889ef0
6 changed files with 65 additions and 113 deletions

View File

@@ -266,7 +266,6 @@ class SuperTransformer(cst.CSTTransformer):
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
if target in self.deleted_targets:
logger.warning(f"Deleted the assign for {target}")
continue
if target in self.all_assign_target:
stmt = self.all_assign_target[target]
@@ -773,6 +772,8 @@ class ModelFileMapper(ModuleMapper):
self.object_dependency_mapping.update(
{obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}
)
# Add them to global nodes
self.global_nodes.update(self.functions)
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
"""Update the global nodes with the assignment from the modular file.
@@ -786,6 +787,8 @@ class ModelFileMapper(ModuleMapper):
self.assignments[assignment] = node
if assignment in object_mapping:
self.object_dependency_mapping[assignment] = object_mapping[assignment]
# Add them to global nodes
self.global_nodes.update(self.assignments)
def _merge_classes(self, classes: dict[str, cst.CSTNode]):
"""Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
@@ -813,10 +816,7 @@ class ModelFileMapper(ModuleMapper):
self._merge_classes(classes)
self.modular_file_start_lines = start_lines
# Correctly re-set the global nodes at this point
self.global_nodes.update(self.functions)
self.global_nodes.update(self.assignments)
# Restrict the dependency mappings to the know entities to avoid Python's built-ins
# Restrict the dependency mappings to the known entities to avoid Python's built-ins and imports
self._restrict_dependencies_to_known_entities()
# Create the global mapping of recursive dependencies for functions and assignments
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
@@ -1024,14 +1024,17 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
import_ref_count[name] = ref_count
imports_to_keep = []
existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly
for node in all_imports:
if m.matches(node, m.If()): # handle safe imports
new_statements = []
for stmt_node in node.body.body:
append_new_import_node(stmt_node, unused_imports, new_statements)
new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
if len(new_statements) > 0:
new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
imports_to_keep.append(new_node)
existing_protected_statements.update({str(stmt) for stmt in new_statements})
else:
append_new_import_node(node, unused_imports, imports_to_keep)