Modular: support for importing functions from any file (#35692)

* fix function imports

* improve comment

* Update modeling_switch_function.py

* make checks more robust

* improvement

* rename

* final test update
This commit is contained in:
Cyril Vallez
2025-01-16 16:37:53 +00:00
committed by GitHub
parent 8ebe9d7166
commit 91be6a5eb2
10 changed files with 305 additions and 43 deletions

View File

@@ -776,7 +776,7 @@ class ModelFileMapper(ModuleMapper):
else:
merged_dependencies.append(class_dep)
# Sort both list according to the order in their respective file
original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])
# Add all original node first, then merged ones
@@ -801,7 +801,7 @@ class ModelFileMapper(ModuleMapper):
else:
original_dependencies.append(dep)
# Sort both list according to the order in their respective file
original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])
# Add all original node first, then merged ones
@@ -1321,6 +1321,20 @@ class ModularFileMapper(ModuleMapper):
self.added_objects_file_mapping[dep] = file
self.functions[dep] = visited_module.global_nodes[dep]
# Add/overwrite the imported functions to other visited modules as well, in case it is absent/different
# in he modeling source file of the inherited class. See `examples/modular-tranformers/modular_switch_function.py`
# and `examples/modular-tranformers/modular_add_function.py` for examples
recursive_dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, set())
node_recursive_dependencies_mapping = {
dep: visited_module.global_nodes[dep] for dep in recursive_dependencies
}
for filename, module_mapper in self.visited_modules.items():
if filename != file:
module_mapper.global_nodes[object_name] = visited_module.functions[object_name]
if len(recursive_dependencies) > 0:
module_mapper.object_recursive_dependency_mapping[object_name] = recursive_dependencies
module_mapper.global_nodes.update(node_recursive_dependencies_mapping)
# Add assignments and their dependencies
elif object_name in visited_module.assignments and object_name not in self.assignments:
self.assignments[object_name] = visited_module.assignments[object_name]