Gemma capping (#34282)

* softcapping

* soft cap before the mask

* style

* ...

* super nit

* update

* fixes

* update

* small issue with modular

* fix modular imports

* update

* fixup

* simplify a hell lot

* simplify cleaning imports

* finish fixing

* update our design

* nits

* use a deprecation cycle

* updates

* Fix modular (recursive deps need to always be computed after merges!)

* push

* fix

* update

* fix modular order

* make fix-copies

* updates

* update

* ?

* don't compile for now

* ?

* fix some stuff

* donc!

* fix copies

* update

* fixup

* ?

* fix two tests

* fix?

* for now, don't use head info

* eager when output attention and sdpa or flash as it's the simplest behaviour (for our tests as well :))

* fix-copies

* revert sdpa check

* Apply suggestions from code review

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>

* rebase, fix-copies and push

* add a slow integration test

* update the test

* fix left padding issue

* fix test

* remove duplicate scaling

* quality

* add a small test and make sure it works

* 2b

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Arthur
2024-11-19 13:52:38 +01:00
committed by GitHub
parent 54739a320e
commit 4bff54f921
8 changed files with 431 additions and 541 deletions

View File

@@ -153,9 +153,9 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
# Handle ClassB.call_to_method
if (
isinstance(original_node.value, cst.Name)
m.matches(original_node.value, m.Name())
and original_node.value.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(
@@ -163,10 +163,10 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
)
# Handle ClassB().call_to_method
elif (
isinstance(original_node.value, cst.Call)
and isinstance(original_node.value.func, cst.Name)
m.matches(original_node.value, m.Call())
and m.matches(original_node.value.func, m.Name())
and original_node.value.func.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
@@ -174,16 +174,16 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
if isinstance(original_node.func, cst.Attribute) and (
if m.matches(original_node.func, m.Attribute()) and (
# Match ClassB().func_a(...)
(
isinstance(original_node.func.value, cst.Call)
and isinstance(original_node.func.value.func, cst.Name)
m.matches(original_node.func.value, m.Call())
and m.matches(original_node.func.value.func, m.Name())
and original_node.func.value.func.value in self.all_bases
)
or
# Match ClassB.func_a(...)
(isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
(m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
):
# Check if the first argument is 'self', and remove it
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
@@ -632,8 +632,10 @@ class ModuleMapper(CSTVisitor, ABC):
for id, node in self.global_nodes.items():
self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
# Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that
# are not part of the recorded objects (i.e. built-in variables, imports, etc)
def _restrict_dependencies_to_known_entities(self):
"""Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
This should be called only after all merging operations have been finalized!!"""
global_objects = set(self.global_nodes.keys())
for object_name, dependencies in self.object_dependency_mapping.items():
self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}
@@ -814,6 +816,8 @@ class ModelFileMapper(ModuleMapper):
# Correctly re-set the global nodes at this point
self.global_nodes.update(self.functions)
self.global_nodes.update(self.assignments)
# Restrict the dependency mappings to the known entities to avoid Python's built-ins
self._restrict_dependencies_to_known_entities()
# Create the global mapping of recursive dependencies for functions and assignments
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
@@ -1142,22 +1146,20 @@ class ModularFileMapper(ModuleMapper):
if assigned_variable == "__all__":
self.all_all_to_add = split_all_assignment(node)
else:
self.current_assignment = assigned_variable
self.assignments[assigned_variable] = node
def leave_Module(self, node):
"""When we leave the modular file, we do the following in order:
1. compute the nested (recursive) function and assignment dependencies
2. for each modeling file found in the imports, rename it with the new model name, visit it, and update
1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
its dependency graph with the new function and assignment definitions found in the modular
3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
3. compute the nested (recursive) function and assignment dependencies
"""
# Takes care of finalizing our visit
super().leave_Module(node)
# 1. compute the nested (recursive) function and assignment dependencies
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
# 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
self.visited_modules = {}
self.renamers = {}
for file, module in self.model_specific_modules.items():
@@ -1177,10 +1179,13 @@ class ModularFileMapper(ModuleMapper):
# We record it so that we can rename classes later the exact same way
self.renamers[file] = renamer
# 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# definitions found in the visited files
self.merge_model_specific_imports(self.visited_modules)
# 3. compute the nested (recursive) function and assignment dependencies
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
# We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later
# Note that we may visit several of the same file types, thus we save them per file type, not file
self.imported_objects_per_file = defaultdict(set)
@@ -1200,9 +1205,9 @@ class ModularFileMapper(ModuleMapper):
if object_name in visited_module.functions and object_name not in self.functions:
self.functions[object_name] = visited_module.functions[object_name]
self.added_objects_file_mapping[object_name] = file
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
self.object_recursive_dependency_mapping[object_name] = dependencies
self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
@@ -1212,9 +1217,9 @@ class ModularFileMapper(ModuleMapper):
elif object_name in visited_module.assignments and object_name not in self.assignments:
self.assignments[object_name] = visited_module.assignments[object_name]
self.added_objects_file_mapping[object_name] = file
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
self.object_recursive_dependency_mapping[object_name] = dependencies
self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
@@ -1222,6 +1227,8 @@ class ModularFileMapper(ModuleMapper):
# Do not forget to re-assign all nodes after the merge
self.global_nodes = {**self.assignments, **self.classes, **self.functions}
# And restrict dependencies to those nodes only
self._restrict_dependencies_to_known_entities()
def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
@@ -1239,10 +1246,11 @@ class ModularFileMapper(ModuleMapper):
else:
original_dependencies.append(dep)
# Sort all lists according to the order in their respective file
all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
all_dependencies = []
for file, dependencies in other_files_dependencies.items():
sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
all_dependencies += sorted_dependencies
all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])
# Add all merged nodes first (one file at a time), then the original ones
for dep in all_dependencies:
@@ -1485,7 +1493,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/gemma/modular_gemma.py"],
default=["src/transformers/models/gemma2/modular_gemma2.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)