Fix modular edge case + modular sorting order (#35562)
* look-ahead negation * re add examples by default * Fix the bug in topological sort * Update create_dependency_mapping.py * start adding test * finalize test * more tests * style * style
This commit is contained in:
@@ -43,7 +43,7 @@ class MyNewModelConfig(PretrainedConfig):
|
|||||||
The non-linear activation function (function or string) in the decoder.
|
The non-linear activation function (function or string) in the decoder.
|
||||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||||
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
|
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
|
||||||
MyNewModel 2 up to 4096, CodeMyNewModel up to 16384.
|
MyNewModel 2 up to 4096, CodeLlama up to 16384.
|
||||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
@@ -110,7 +110,7 @@ class MyNewModelConfig(PretrainedConfig):
|
|||||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
||||||
head_dim (`int`, *optional*):
|
head_dim (`int`, *optional*):
|
||||||
The attention head dimension. If None, it will default to hidden_size // num_heads
|
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import MyNewModelModel, MyNewModelConfig
|
>>> from transformers import MyNewModelModel, MyNewModelConfig
|
||||||
|
|||||||
@@ -597,7 +597,7 @@ class DummyModel(DummyPreTrainedModel):
|
|||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
):
|
):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -597,7 +597,7 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
|
|||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
):
|
):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -602,7 +602,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
|
|||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
):
|
):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -519,7 +519,7 @@ class SuperModel(SuperPreTrainedModel):
|
|||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
):
|
):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -612,11 +612,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class DiffLlamaRotaryEmbedding(nn.Module):
|
class DiffLlamaRotaryEmbedding(nn.Module):
|
||||||
def __init__(
|
def __init__(self, config: DiffLlamaConfig, device=None):
|
||||||
self,
|
|
||||||
config: DiffLlamaConfig,
|
|
||||||
device=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
@@ -898,7 +894,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
|
|||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
):
|
):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||||
return attention_mask
|
return attention_mask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
63
tests/repo_utils/modular/test_conversion_order.py
Normal file
63
tests/repo_utils/modular/test_conversion_order.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||||
|
sys.path.append(os.path.join(ROOT_DIR, "utils"))
|
||||||
|
|
||||||
|
import create_dependency_mapping # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# This is equivalent to `all` in the current library state (as of 09/01/2025)
|
||||||
|
MODEL_ROOT = os.path.join("src", "transformers", "models")
|
||||||
|
FILES_TO_PARSE = [
|
||||||
|
os.path.join(MODEL_ROOT, "starcoder2", "modular_starcoder2.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "gemma", "modular_gemma.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "olmo2", "modular_olmo2.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "diffllama", "modular_diffllama.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "granite", "modular_granite.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "gemma2", "modular_gemma2.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "mixtral", "modular_mixtral.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "colpali", "modular_colpali.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "deformable_detr", "modular_deformable_detr.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "aria", "modular_aria.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "ijepa", "modular_ijepa.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "bamba", "modular_bamba.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "dinov2_with_registers", "modular_dinov2_with_registers.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "instructblipvideo", "modular_instructblipvideo.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "glm", "modular_glm.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "phi", "modular_phi.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "mistral", "modular_mistral.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "phi3", "modular_phi3.py"),
|
||||||
|
os.path.join(MODEL_ROOT, "cohere", "modular_cohere.py"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def appear_after(model1: str, model2: str, priority_list: list[str]) -> bool:
|
||||||
|
"""Return True if `model1` appear after `model2` in `priority_list`."""
|
||||||
|
return priority_list.index(model1) > priority_list.index(model2)
|
||||||
|
|
||||||
|
|
||||||
|
class ConversionOrderTest(unittest.TestCase):
|
||||||
|
def test_conversion_order(self):
|
||||||
|
# Find the order
|
||||||
|
priority_list = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
|
||||||
|
# Extract just the model names
|
||||||
|
model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list]
|
||||||
|
|
||||||
|
# These are based on what the current library order should be (as of 09/01/2025)
|
||||||
|
self.assertTrue(appear_after("mixtral", "mistral", model_priority_list))
|
||||||
|
self.assertTrue(appear_after("gemma2", "gemma", model_priority_list))
|
||||||
|
self.assertTrue(appear_after("starcoder2", "mistral", model_priority_list))
|
||||||
|
self.assertTrue(appear_after("olmo2", "olmo", model_priority_list))
|
||||||
|
self.assertTrue(appear_after("diffllama", "mistral", model_priority_list))
|
||||||
|
self.assertTrue(appear_after("cohere2", "gemma2", model_priority_list))
|
||||||
|
self.assertTrue(appear_after("cohere2", "cohere", model_priority_list))
|
||||||
|
self.assertTrue(appear_after("phi3", "mistral", model_priority_list))
|
||||||
@@ -3,51 +3,29 @@ from collections import defaultdict
|
|||||||
|
|
||||||
|
|
||||||
# Function to perform topological sorting
|
# Function to perform topological sorting
|
||||||
def topological_sort(dependencies):
|
def topological_sort(dependencies: dict):
|
||||||
new_dependencies = {}
|
# Nodes are the name of the models to convert (we only add those to the graph)
|
||||||
graph = defaultdict(list)
|
nodes = {node.rsplit("modular_", 1)[1].replace(".py", "") for node in dependencies.keys()}
|
||||||
|
# This will be a graph from models to convert, to models to convert that should be converted before (as they are a dependency)
|
||||||
|
graph = {}
|
||||||
|
name_mapping = {}
|
||||||
for node, deps in dependencies.items():
|
for node, deps in dependencies.items():
|
||||||
node_name = node.split("/")[-2]
|
node_name = node.rsplit("modular_", 1)[1].replace(".py", "")
|
||||||
for dep in deps:
|
dep_names = {dep.split(".")[-2] for dep in deps}
|
||||||
dep_name = dep.split(".")[-2]
|
dependencies = {dep for dep in dep_names if dep in nodes and dep != node_name}
|
||||||
if dep_name == node_name:
|
graph[node_name] = dependencies
|
||||||
# Skip self dependencies for topological sort as they create cycles
|
name_mapping[node_name] = node
|
||||||
continue
|
|
||||||
if "example" not in node and "auto" not in dep and node_name not in graph[dep_name]:
|
|
||||||
graph[dep_name].append(node_name)
|
|
||||||
new_dependencies[node_name] = node
|
|
||||||
|
|
||||||
# Create a graph and in-degree count for each node
|
sorting_list = []
|
||||||
def filter_one_by_one(filtered_list, reverse):
|
while len(graph) > 0:
|
||||||
if len(reverse) == 0:
|
# Find the nodes with 0 out-degree
|
||||||
return filtered_list
|
leaf_nodes = {node for node in graph if len(graph[node]) == 0}
|
||||||
|
# Add them to the list
|
||||||
|
sorting_list += list(leaf_nodes)
|
||||||
|
# Remove the leafs from the graph (and from the deps of other nodes)
|
||||||
|
graph = {node: deps - leaf_nodes for node, deps in graph.items() if node not in leaf_nodes}
|
||||||
|
|
||||||
graph = defaultdict(list)
|
return [name_mapping[x] for x in sorting_list]
|
||||||
# Build the graph
|
|
||||||
for node, deps in reverse.items():
|
|
||||||
for dep in deps:
|
|
||||||
graph[dep].append(node)
|
|
||||||
|
|
||||||
base_modules = set(reverse.keys()) - set(graph.keys())
|
|
||||||
if base_modules == reverse.keys():
|
|
||||||
# we are at the end
|
|
||||||
return filtered_list + list(graph.keys())
|
|
||||||
to_add = []
|
|
||||||
for k in graph.keys():
|
|
||||||
if len(graph[k]) == 1 and graph[k][0] in base_modules:
|
|
||||||
if graph[k][0] in reverse:
|
|
||||||
del reverse[graph[k][0]]
|
|
||||||
if k not in filtered_list:
|
|
||||||
to_add += [k]
|
|
||||||
for k in base_modules:
|
|
||||||
if k not in filtered_list:
|
|
||||||
to_add += [k]
|
|
||||||
filtered_list += list(to_add)
|
|
||||||
return filter_one_by_one(filtered_list, reverse)
|
|
||||||
|
|
||||||
final_order = filter_one_by_one([], graph)
|
|
||||||
|
|
||||||
return [new_dependencies.get(k) for k in final_order if k in new_dependencies]
|
|
||||||
|
|
||||||
|
|
||||||
# Function to extract class and import info from a file
|
# Function to extract class and import info from a file
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ def get_module_source_from_name(module_name: str) -> str:
|
|||||||
def preserve_case_replace(text, patterns: dict, default_name: str):
|
def preserve_case_replace(text, patterns: dict, default_name: str):
|
||||||
# Create a regex pattern to match all variations
|
# Create a regex pattern to match all variations
|
||||||
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
|
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
|
||||||
compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
|
compiled_regex = re.compile(f"(?<![a-z0-9])({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
|
||||||
|
|
||||||
def replace(match):
|
def replace(match):
|
||||||
matched_pattern = match.group(1)
|
matched_pattern = match.group(1)
|
||||||
@@ -1691,9 +1691,13 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.files_to_parse == ["all"]:
|
if args.files_to_parse == ["all"]:
|
||||||
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
||||||
args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True)
|
if args.files_to_parse == ["examples"]:
|
||||||
|
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)
|
||||||
|
|
||||||
for file_name in find_priority_list(args.files_to_parse):
|
priority_list = find_priority_list(args.files_to_parse)
|
||||||
|
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"
|
||||||
|
|
||||||
|
for file_name in priority_list:
|
||||||
print(f"Converting {file_name} to a single model single file format")
|
print(f"Converting {file_name} to a single model single file format")
|
||||||
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
|
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
|
||||||
converted_files = convert_modular_file(file_name)
|
converted_files = convert_modular_file(file_name)
|
||||||
|
|||||||
Reference in New Issue
Block a user