Gemma capping (#34282)
* softcapping * soft cap before the mask * style * ... * super nit * update * fixes * update * small issue with modular * fix modular imports * update * fixup * simplify a hell lot * simplify cleaning imports * finish fixing * update our design * nits * use a deprecation cycle * updates * Fix modular (recursive deps need to always be computed after merges!) * push * fix * update * fix modular order * make fix-copies * updates * update * ? * don't compile for now * ? * fix some stuff * donc! * fix copies * update * fixup * ? * fix two tests * fix? * for now, don't use head info * eager when output attentoin and sdpa or flash as it's the simplest behaviour (for our tests as well :)) * fix-copies * revert sdpa check * Apply suggestions from code review Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co> * rebase, fix-copies and push * add a slow integration test * update the test * fix left padding issue * fix test * remove duplicate scaling * quality * add a small test and make sure it works * 2b --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -153,9 +153,9 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
|
||||
# Handle ClassB.call_to_method
|
||||
if (
|
||||
isinstance(original_node.value, cst.Name)
|
||||
m.matches(original_node.value, m.Name())
|
||||
and original_node.value.value in self.all_bases
|
||||
and isinstance(original_node.attr, cst.Name)
|
||||
and m.matches(original_node.attr, m.Name())
|
||||
):
|
||||
# Replace with super().call_to_method
|
||||
return updated_node.with_changes(
|
||||
@@ -163,10 +163,10 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
)
|
||||
# Handle ClassB().call_to_method
|
||||
elif (
|
||||
isinstance(original_node.value, cst.Call)
|
||||
and isinstance(original_node.value.func, cst.Name)
|
||||
m.matches(original_node.value, m.Call())
|
||||
and m.matches(original_node.value.func, m.Name())
|
||||
and original_node.value.func.value in self.all_bases
|
||||
and isinstance(original_node.attr, cst.Name)
|
||||
and m.matches(original_node.attr, m.Name())
|
||||
):
|
||||
# Replace with super().call_to_method
|
||||
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
|
||||
@@ -174,16 +174,16 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
|
||||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
|
||||
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
|
||||
if isinstance(original_node.func, cst.Attribute) and (
|
||||
if m.matches(original_node.func, m.Attribute()) and (
|
||||
# Match ClassB().func_a(...)
|
||||
(
|
||||
isinstance(original_node.func.value, cst.Call)
|
||||
and isinstance(original_node.func.value.func, cst.Name)
|
||||
m.matches(original_node.func.value, m.Call())
|
||||
and m.matches(original_node.func.value.func, m.Name())
|
||||
and original_node.func.value.func.value in self.all_bases
|
||||
)
|
||||
or
|
||||
# Match ClassB.func_a(...)
|
||||
(isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
|
||||
(m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
|
||||
):
|
||||
# Check if the first argument is 'self', and remove it
|
||||
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
|
||||
@@ -632,8 +632,10 @@ class ModuleMapper(CSTVisitor, ABC):
|
||||
for id, node in self.global_nodes.items():
|
||||
self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
|
||||
|
||||
# Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that
|
||||
# are not part of the recorded objects (i.e. built-in variables, imports, etc)
|
||||
def _restrict_dependencies_to_known_entities(self):
|
||||
"""Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
|
||||
are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
|
||||
This should be called only after all merging operations have been finalized!!"""
|
||||
global_objects = set(self.global_nodes.keys())
|
||||
for object_name, dependencies in self.object_dependency_mapping.items():
|
||||
self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}
|
||||
@@ -814,6 +816,8 @@ class ModelFileMapper(ModuleMapper):
|
||||
# Correctly re-set the global nodes at this point
|
||||
self.global_nodes.update(self.functions)
|
||||
self.global_nodes.update(self.assignments)
|
||||
# Restrict the dependency mappings to the known entities to avoid Python's built-ins
|
||||
self._restrict_dependencies_to_known_entities()
|
||||
# Create the global mapping of recursive dependencies for functions and assignments
|
||||
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
|
||||
|
||||
@@ -1142,22 +1146,20 @@ class ModularFileMapper(ModuleMapper):
|
||||
if assigned_variable == "__all__":
|
||||
self.all_all_to_add = split_all_assignment(node)
|
||||
else:
|
||||
self.current_assignment = assigned_variable
|
||||
self.assignments[assigned_variable] = node
|
||||
|
||||
def leave_Module(self, node):
|
||||
"""When we leave the modular file, we do the following in order:
|
||||
1. compute the nested (recursive) function and assignment dependencies
|
||||
2. for each modeling file found in the imports, rename it with the new model name, visit it, and update
|
||||
1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
|
||||
its dependency graph with the new function and assignment definitions found in the modular
|
||||
3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
|
||||
2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
|
||||
3. compute the nested (recursive) function and assignment dependencies
|
||||
"""
|
||||
# Takes care of finalizing our visit
|
||||
super().leave_Module(node)
|
||||
|
||||
# 1. compute the nested (recursive) function and assignment dependencies
|
||||
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
|
||||
|
||||
# 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
|
||||
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
|
||||
self.visited_modules = {}
|
||||
self.renamers = {}
|
||||
for file, module in self.model_specific_modules.items():
|
||||
@@ -1177,10 +1179,13 @@ class ModularFileMapper(ModuleMapper):
|
||||
# We record it so that we can rename classes later the exact same way
|
||||
self.renamers[file] = renamer
|
||||
|
||||
# 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
|
||||
# 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
|
||||
# definitions found in the visited files
|
||||
self.merge_model_specific_imports(self.visited_modules)
|
||||
|
||||
# 3. compute the nested (recursive) function and assignment dependencies
|
||||
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
|
||||
|
||||
# We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later
|
||||
# Note that we may visit several of the same file types, thus we save them per file type, not file
|
||||
self.imported_objects_per_file = defaultdict(set)
|
||||
@@ -1200,9 +1205,9 @@ class ModularFileMapper(ModuleMapper):
|
||||
if object_name in visited_module.functions and object_name not in self.functions:
|
||||
self.functions[object_name] = visited_module.functions[object_name]
|
||||
self.added_objects_file_mapping[object_name] = file
|
||||
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
|
||||
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
|
||||
if dependencies is not None:
|
||||
self.object_recursive_dependency_mapping[object_name] = dependencies
|
||||
self.object_dependency_mapping[object_name] = dependencies
|
||||
for dep in dependencies:
|
||||
if dep not in self.global_nodes:
|
||||
self.added_objects_file_mapping[dep] = file
|
||||
@@ -1212,9 +1217,9 @@ class ModularFileMapper(ModuleMapper):
|
||||
elif object_name in visited_module.assignments and object_name not in self.assignments:
|
||||
self.assignments[object_name] = visited_module.assignments[object_name]
|
||||
self.added_objects_file_mapping[object_name] = file
|
||||
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
|
||||
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
|
||||
if dependencies is not None:
|
||||
self.object_recursive_dependency_mapping[object_name] = dependencies
|
||||
self.object_dependency_mapping[object_name] = dependencies
|
||||
for dep in dependencies:
|
||||
if dep not in self.global_nodes:
|
||||
self.added_objects_file_mapping[dep] = file
|
||||
@@ -1222,6 +1227,8 @@ class ModularFileMapper(ModuleMapper):
|
||||
|
||||
# Do not forget to re-assign all nodes after the merge
|
||||
self.global_nodes = {**self.assignments, **self.classes, **self.functions}
|
||||
# And restrict dependencies to those nodes only
|
||||
self._restrict_dependencies_to_known_entities()
|
||||
|
||||
def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
|
||||
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
|
||||
@@ -1239,10 +1246,11 @@ class ModularFileMapper(ModuleMapper):
|
||||
else:
|
||||
original_dependencies.append(dep)
|
||||
# Sort all lists according to the order in their respective file
|
||||
all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
|
||||
all_dependencies = []
|
||||
for file, dependencies in other_files_dependencies.items():
|
||||
sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
|
||||
all_dependencies += sorted_dependencies
|
||||
all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])
|
||||
|
||||
# Add all merged nodes first (one file at a time), then the original ones
|
||||
for dep in all_dependencies:
|
||||
@@ -1485,7 +1493,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--files_to_parse",
|
||||
default=["src/transformers/models/gemma/modular_gemma.py"],
|
||||
default=["src/transformers/models/gemma2/modular_gemma2.py"],
|
||||
nargs="+",
|
||||
help="A list of `modular_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user