diff --git a/conftest.py b/conftest.py
index 67e6eddfb8..e07103e4c3 100644
--- a/conftest.py
+++ b/conftest.py
@@ -29,7 +29,6 @@ from transformers.testing_utils import HfDoctestModule, HfDocTestParser
NOT_DEVICE_TESTS = {
"test_tokenization",
"test_tokenization_mistral_common",
- "test_processor",
"test_processing",
"test_beam_constraints",
"test_configuration_utils",
diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py
index cbb751ad75..74330b8d3c 100644
--- a/src/transformers/commands/add_new_model_like.py
+++ b/src/transformers/commands/add_new_model_like.py
@@ -11,1562 +11,569 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-# 1. Standard library
import difflib
-import json
import os
import re
+import subprocess
+import textwrap
from argparse import ArgumentParser, Namespace
-from dataclasses import dataclass
from datetime import date
-from itertools import chain
from pathlib import Path
-from re import Pattern
from typing import Any, Callable, Optional, Union
-import yaml
-
-from ..models import auto as auto_module
-from ..models.auto.configuration_auto import model_type_to_module_name
-from ..utils import (
- is_flax_available,
- is_tf_available,
- is_torch_available,
- logging,
-)
+from ..models.auto.configuration_auto import CONFIG_MAPPING_NAMES, MODEL_NAMES_MAPPING
+from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
+from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
+from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
+from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
+from ..models.auto.video_processing_auto import VIDEO_PROCESSOR_MAPPING_NAMES
+from ..utils import is_libcst_available
from . import BaseTransformersCLICommand
from .add_fast_image_processor import add_fast_image_processor
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+# We protect this import to avoid requiring it for all `transformers` CLI commands - however it is actually
+# strictly required for this one (we need it both for modular and for the following Visitor)
+if is_libcst_available():
+ import libcst as cst
+ from libcst import CSTVisitor
+ from libcst import matchers as m
+
+ class ClassFinder(CSTVisitor):
+ """
+ A visitor to find all classes in a python module.
+
+ After the traversal, `classes` holds the name of every class definition that was visited, and `public_classes`
+ holds the entries of the `__all__` assignment, when one was matched outside of a class body.
+ """
+
+ def __init__(self):
+ self.classes: list = []
+ self.public_classes: list = []
+ # Set while the traversal is inside a class definition, so that assignments in class bodies are ignored
+ self.is_in_class = False
+
+ def visit_ClassDef(self, node: cst.ClassDef) -> None:
+ """Record class names. We assume classes always only appear at top-level (i.e. no class definition in function or similar)"""
+ self.classes.append(node.name.value)
+ self.is_in_class = True
+
+ def leave_ClassDef(self, node: cst.ClassDef):
+ # NOTE(review): this is a single boolean rather than a depth counter, so leaving a nested class would clear the
+ # flag while still inside the outer class - consistent with the top-level-only assumption of `visit_ClassDef`
+ self.is_in_class = False
+
+ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine):
+ """Record all public classes inside the `__all__` assignment."""
+ # Only statement lines made of exactly one assignment to exactly one plain name are considered
+ simple_top_level_assign_structure = m.SimpleStatementLine(
+ body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
+ )
+ if not self.is_in_class and m.matches(node, simple_top_level_assign_structure):
+ assigned_variable = node.body[0].targets[0].target.value
+ if assigned_variable == "__all__":
+ # NOTE(review): relies on the assigned value exposing `.elements`; presumably a list or tuple literal - confirm
+ elements = node.body[0].value.elements
+ # NOTE(review): presumably the entries are string literals, in which case `.value` would keep the surrounding
+ # quotes - confirm how callers consume `public_classes`
+ self.public_classes = [element.value.value for element in elements]
CURRENT_YEAR = date.today().year
TRANSFORMERS_PATH = Path(__file__).parent.parent
REPO_PATH = TRANSFORMERS_PATH.parent.parent
+# License header template: `CURRENT_YEAR` is interpolated into the copyright line, and the trailing `.lstrip()` removes
+# the newline that follows the opening triple quote so that the text starts directly with the `# coding=utf-8` line.
+COPYRIGHT = f"""
+# coding=utf-8
+# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""".lstrip()
-@dataclass
-class ModelPatterns:
+
+class ModelInfos(object):
"""
- Holds the basic information about a new model for the add-new-model-like command.
-
- Args:
- model_name (`str`): The model name.
- checkpoint (`str`): The checkpoint to use for doc examples.
- model_type (`str`, *optional*):
- The model type, the identifier used internally in the library like `bert` or `xlm-roberta`. Will default to
- `model_name` lowercased with spaces replaced with minuses (-).
- model_lower_cased (`str`, *optional*):
- The lowercased version of the model name, to use for the module name or function names. Will default to
- `model_name` lowercased with spaces and minuses replaced with underscores.
- model_camel_cased (`str`, *optional*):
- The camel-cased version of the model name, to use for the class names. Will default to `model_name`
- camel-cased (with spaces and minuses both considered as word separators.
- model_upper_cased (`str`, *optional*):
- The uppercased version of the model name, to use for the constant names. Will default to `model_name`
- uppercased with spaces and minuses replaced with underscores.
- config_class (`str`, *optional*):
- The tokenizer class associated with this model. Will default to `"{model_camel_cased}Config"`.
- tokenizer_class (`str`, *optional*):
- The tokenizer class associated with this model (leave to `None` for models that don't use a tokenizer).
- image_processor_class (`str`, *optional*):
- The image processor class associated with this model (leave to `None` for models that don't use an image
- processor).
- image_processor_fast_class (`str`, *optional*):
- The fast image processor class associated with this model (leave to `None` for models that don't use a fast
- image processor).
- feature_extractor_class (`str`, *optional*):
- The feature extractor class associated with this model (leave to `None` for models that don't use a feature
- extractor).
- processor_class (`str`, *optional*):
- The processor class associated with this model (leave to `None` for models that don't use a processor).
+ Retrieve the basic information about the classes of an existing model.
"""
- model_name: str
- checkpoint: str
- model_type: Optional[str] = None
- model_lower_cased: Optional[str] = None
- model_camel_cased: Optional[str] = None
- model_upper_cased: Optional[str] = None
- config_class: Optional[str] = None
- tokenizer_class: Optional[str] = None
- image_processor_class: Optional[str] = None
- image_processor_fast_class: Optional[str] = None
- feature_extractor_class: Optional[str] = None
- processor_class: Optional[str] = None
+    def __init__(self, lowercase_name: str):
+        """
+        Look up an existing model in the auto mappings and store the names of its classes.
+
+        Args:
+            lowercase_name (`str`):
+                The name of an existing model. It is lowercased, with spaces and dashes replaced by underscores, before
+                being looked up in `CONFIG_MAPPING_NAMES`; if that form is unknown, the dashed form is tried as well.
+
+        Raises:
+            `ValueError`: If neither form of the name is a key of `CONFIG_MAPPING_NAMES`.
+        """
+        # Just to make sure it's indeed lowercase
+        self.lowercase_name = lowercase_name.lower().replace(" ", "_").replace("-", "_")
+        if self.lowercase_name not in CONFIG_MAPPING_NAMES:
+            # `str.replace` returns a new string: the result has to be assigned back, otherwise the dashed fallback
+            # is never actually tried and model types spelled with dashes are rejected below
+            self.lowercase_name = self.lowercase_name.replace("_", "-")
+        if self.lowercase_name not in CONFIG_MAPPING_NAMES:
+            raise ValueError(f"{lowercase_name} is not a valid model name")
- def __post_init__(self):
- if self.model_type is None:
- self.model_type = self.model_name.lower().replace(" ", "-")
- if self.model_lower_cased is None:
- self.model_lower_cased = self.model_name.lower().replace(" ", "_").replace("-", "_")
- if self.model_camel_cased is None:
- # Split the model name on - and space
- words = self.model_name.split(" ")
- words = list(chain(*[w.split("-") for w in words]))
- # Make sure each word is capitalized
- words = [w[0].upper() + w[1:] for w in words]
- self.model_camel_cased = "".join(words)
- if self.model_upper_cased is None:
- self.model_upper_cased = self.model_name.upper().replace(" ", "_").replace("-", "_")
- if self.config_class is None:
- self.config_class = f"{self.model_camel_cased}Config"
+ self.paper_name = MODEL_NAMES_MAPPING[self.lowercase_name]
+ self.config_class = CONFIG_MAPPING_NAMES[self.lowercase_name]
+ self.camelcase_name = self.config_class.replace("Config", "")
-
-ATTRIBUTE_TO_PLACEHOLDER = {
- "config_class": "[CONFIG_CLASS]",
- "tokenizer_class": "[TOKENIZER_CLASS]",
- "image_processor_class": "[IMAGE_PROCESSOR_CLASS]",
- "image_processor_fast_class": "[IMAGE_PROCESSOR_FAST_CLASS]",
- "feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]",
- "processor_class": "[PROCESSOR_CLASS]",
- "checkpoint": "[CHECKPOINT]",
- "model_type": "[MODEL_TYPE]",
- "model_upper_cased": "[MODEL_UPPER_CASED]",
- "model_camel_cased": "[MODEL_CAMELCASED]",
- "model_lower_cased": "[MODEL_LOWER_CASED]",
- "model_name": "[MODEL_NAME]",
-}
-
-
-def is_empty_line(line: str) -> bool:
- """
- Determines whether a line is empty or not.
- """
- return len(line) == 0 or line.isspace()
-
-
-def find_indent(line: str) -> int:
- """
- Returns the number of spaces that start a line indent.
- """
- search = re.search(r"^(\s*)(?:\S|$)", line)
- if search is None:
- return 0
- return len(search.groups()[0])
-
-
-def parse_module_content(content: str) -> list[str]:
- """
- Parse the content of a module in the list of objects it defines.
-
- Args:
- content (`str`): The content to parse
-
- Returns:
- `list[str]`: The list of objects defined in the module.
- """
- objects = []
- current_object = []
- lines = content.split("\n")
- # Doc-styler takes everything between two triple quotes in docstrings, so we need a fake """ here to go with this.
- end_markers = [")", "]", "}", '"""']
-
- for line in lines:
- # End of an object
- is_valid_object = len(current_object) > 0
- if is_valid_object and len(current_object) == 1:
- is_valid_object = not current_object[0].startswith("# Copied from")
- if not is_empty_line(line) and find_indent(line) == 0 and is_valid_object:
- # Closing parts should be included in current object
- if line in end_markers:
- current_object.append(line)
- objects.append("\n".join(current_object))
- current_object = []
- else:
- objects.append("\n".join(current_object))
- current_object = [line]
- else:
- current_object.append(line)
-
- # Add last object
- if len(current_object) > 0:
- objects.append("\n".join(current_object))
-
- return objects
-
-
-def extract_block(content: str, indent_level: int = 0) -> str:
- """Return the first block in `content` with the indent level `indent_level`.
-
- The first line in `content` should be indented at `indent_level` level, otherwise an error will be thrown.
-
- This method will immediately stop the search when a (non-empty) line with indent level less than `indent_level` is
- encountered.
-
- Args:
- content (`str`): The content to parse
- indent_level (`int`, *optional*, default to 0): The indent level of the blocks to search for
-
- Returns:
- `str`: The first block in `content` with the indent level `indent_level`.
- """
- current_object = []
- lines = content.split("\n")
- # Doc-styler takes everything between two triple quotes in docstrings, so we need a fake """ here to go with this.
- end_markers = [")", "]", "}", '"""']
-
- for idx, line in enumerate(lines):
- if idx == 0 and indent_level > 0 and not is_empty_line(line) and find_indent(line) != indent_level:
- raise ValueError(
- f"When `indent_level > 0`, the first line in `content` should have indent level {indent_level}. Got "
- f"{find_indent(line)} instead."
+ # Get tokenizer class
+ if self.lowercase_name in TOKENIZER_MAPPING_NAMES:
+ self.tokenizer_class, self.fast_tokenizer_class = TOKENIZER_MAPPING_NAMES[self.lowercase_name]
+ self.fast_tokenizer_class = (
+ None if self.fast_tokenizer_class == "PreTrainedTokenizerFast" else self.fast_tokenizer_class
)
-
- if find_indent(line) < indent_level and not is_empty_line(line):
- break
-
- # End of an object
- is_valid_object = len(current_object) > 0
- if (
- not is_empty_line(line)
- and not line.endswith(":")
- and find_indent(line) == indent_level
- and is_valid_object
- ):
- # Closing parts should be included in current object
- if line.lstrip() in end_markers:
- current_object.append(line)
- return "\n".join(current_object)
else:
- current_object.append(line)
+ self.tokenizer_class, self.fast_tokenizer_class = None, None
- # Add last object
- if len(current_object) > 0:
- return "\n".join(current_object)
+ self.image_processor_class, self.fast_image_processor_class = IMAGE_PROCESSOR_MAPPING_NAMES.get(
+ self.lowercase_name, (None, None)
+ )
+ self.video_processor_class = VIDEO_PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
+ self.feature_extractor_class = FEATURE_EXTRACTOR_MAPPING_NAMES.get(self.lowercase_name, None)
+ self.processor_class = PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
-def add_content_to_text(
- text: str,
- content: str,
- add_after: Optional[Union[str, Pattern]] = None,
- add_before: Optional[Union[str, Pattern]] = None,
- exact_match: bool = False,
-) -> str:
- """
- A utility to add some content inside a given text.
-
- Args:
- text (`str`): The text in which we want to insert some content.
- content (`str`): The content to add.
- add_after (`str` or `Pattern`):
- The pattern to test on a line of `text`, the new content is added after the first instance matching it.
- add_before (`str` or `Pattern`):
- The pattern to test on a line of `text`, the new content is added before the first instance matching it.
- exact_match (`bool`, *optional*, defaults to `False`):
- A line is considered a match with `add_after` or `add_before` if it matches exactly when `exact_match=True`,
- otherwise, if `add_after`/`add_before` is present in the line.
-
-
-
- The arguments `add_after` and `add_before` are mutually exclusive, and one exactly needs to be provided.
-
-
-
- Returns:
- `str`: The text with the new content added if a match was found.
- """
- if add_after is None and add_before is None:
- raise ValueError("You need to pass either `add_after` or `add_before`")
- if add_after is not None and add_before is not None:
- raise ValueError("You can't pass both `add_after` or `add_before`")
- pattern = add_after if add_before is None else add_before
-
- def this_is_the_line(line):
- if isinstance(pattern, Pattern):
- return pattern.search(line) is not None
- elif exact_match:
- return pattern == line
- else:
- return pattern in line
-
- new_lines = []
- for line in text.split("\n"):
- if this_is_the_line(line):
- if add_before is not None:
- new_lines.append(content)
- new_lines.append(line)
- if add_after is not None:
- new_lines.append(content)
- else:
- new_lines.append(line)
-
- return "\n".join(new_lines)
-
-
-def add_content_to_file(
- file_name: Union[str, os.PathLike],
- content: str,
- add_after: Optional[Union[str, Pattern]] = None,
- add_before: Optional[Union[str, Pattern]] = None,
- exact_match: bool = False,
-):
+def add_content_to_file(file_name: Union[str, os.PathLike], new_content: str, add_after: str):
"""
A utility to add some content inside a given file.
Args:
- file_name (`str` or `os.PathLike`): The name of the file in which we want to insert some content.
- content (`str`): The content to add.
- add_after (`str` or `Pattern`):
- The pattern to test on a line of `text`, the new content is added after the first instance matching it.
- add_before (`str` or `Pattern`):
- The pattern to test on a line of `text`, the new content is added before the first instance matching it.
- exact_match (`bool`, *optional*, defaults to `False`):
- A line is considered a match with `add_after` or `add_before` if it matches exactly when `exact_match=True`,
- otherwise, if `add_after`/`add_before` is present in the line.
-
-
-
- The arguments `add_after` and `add_before` are mutually exclusive, and one exactly needs to be provided.
-
-
+ file_name (`str` or `os.PathLike`):
+ The name of the file in which we want to insert some content.
+ new_content (`str`):
+ The content to add.
+ add_after (`str`):
+ The new content is added just after the first instance matching it.
"""
with open(file_name, "r", encoding="utf-8") as f:
old_content = f.read()
- new_content = add_content_to_text(
- old_content, content, add_after=add_after, add_before=add_before, exact_match=exact_match
- )
+    # `str.partition` splits around the first occurrence of `add_after` and returns an empty separator when the anchor
+    # is missing, which lets us fail with an explicit message instead of an opaque tuple-unpacking error
+    before, anchor, after = old_content.partition(add_after)
+    if not anchor:
+        raise ValueError(f"Could not find the anchor {add_after!r} in {file_name}, so no content was added.")
+    new_content = before + anchor + new_content + after
with open(file_name, "w", encoding="utf-8") as f:
f.write(new_content)
-def replace_model_patterns(
- text: str, old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns
-) -> tuple[str, str]:
- """
- Replace all patterns present in a given text.
-
- Args:
- text (`str`): The text to treat.
- old_model_patterns (`ModelPatterns`): The patterns for the old model.
- new_model_patterns (`ModelPatterns`): The patterns for the new model.
-
- Returns:
- `Tuple(str, str)`: A tuple of with the treated text and the replacement actually done in it.
- """
- # The order is crucially important as we will check and replace in that order. For instance the config probably
- # contains the camel-cased named, but will be treated before.
- attributes_to_check = ["config_class"]
- # Add relevant preprocessing classes
- for attr in [
- "tokenizer_class",
- "image_processor_class",
- "image_processor_fast_class",
- "feature_extractor_class",
- "processor_class",
- ]:
- if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
- attributes_to_check.append(attr)
-
- # Special cases for checkpoint and model_type
- if old_model_patterns.checkpoint not in [old_model_patterns.model_type, old_model_patterns.model_lower_cased]:
- attributes_to_check.append("checkpoint")
- if old_model_patterns.model_type != old_model_patterns.model_lower_cased:
- attributes_to_check.append("model_type")
- else:
- text = re.sub(
- rf'(\s*)model_type = "{old_model_patterns.model_type}"',
- r'\1model_type = "[MODEL_TYPE]"',
- text,
- )
-
- # Special case when the model camel cased and upper cased names are the same for the old model (like for GPT2) but
- # not the new one. We can't just do a replace in all the text and will need a special regex
- if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
- old_model_value = old_model_patterns.model_upper_cased
- if re.search(rf"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
- text = re.sub(rf"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
- else:
- attributes_to_check.append("model_upper_cased")
-
- attributes_to_check.extend(["model_camel_cased", "model_lower_cased", "model_name"])
-
- # Now let's replace every other attribute by their placeholder
- for attr in attributes_to_check:
- text = text.replace(getattr(old_model_patterns, attr), ATTRIBUTE_TO_PLACEHOLDER[attr])
-
- # Finally we can replace the placeholder byt the new values.
- replacements = []
- for attr, placeholder in ATTRIBUTE_TO_PLACEHOLDER.items():
- if placeholder in text:
- replacements.append((getattr(old_model_patterns, attr), getattr(new_model_patterns, attr)))
- text = text.replace(placeholder, getattr(new_model_patterns, attr))
-
- # If we have two inconsistent replacements, we don't return anything (ex: GPT2->GPT_NEW and GPT2->GPTNew)
- old_replacement_values = [old for old, new in replacements]
- if len(set(old_replacement_values)) != len(old_replacement_values):
- return text, ""
-
- replacements = simplify_replacements(replacements)
- replacements = [f"{old}->{new}" for old, new in replacements]
- return text, ",".join(replacements)
-
-
-def simplify_replacements(replacements):
- """
- Simplify a list of replacement patterns to make sure there are no needless ones.
-
- For instance in the sequence "Bert->BertNew, BertConfig->BertNewConfig, bert->bert_new", the replacement
- "BertConfig->BertNewConfig" is implied by "Bert->BertNew" so not needed.
-
- Args:
- replacements (`list[tuple[str, str]]`): List of patterns (old, new)
-
- Returns:
- `list[tuple[str, str]]`: The list of patterns simplified.
- """
- if len(replacements) <= 1:
- # Nothing to simplify
- return replacements
-
- # Next let's sort replacements by length as a replacement can only "imply" another replacement if it's shorter.
- replacements.sort(key=lambda x: len(x[0]))
-
- idx = 0
- while idx < len(replacements):
- old, new = replacements[idx]
- # Loop through all replacements after
- j = idx + 1
- while j < len(replacements):
- old_2, new_2 = replacements[j]
- # If the replacement is implied by the current one, we can drop it.
- if old_2.replace(old, new) == new_2:
- replacements.pop(j)
- else:
- j += 1
- idx += 1
-
- return replacements
-
-
-def get_module_from_file(module_file: Union[str, os.PathLike]) -> str:
- """
- Returns the module name corresponding to a module file.
- """
- full_module_path = Path(module_file).absolute()
- module_parts = full_module_path.with_suffix("").parts
-
- # Find the first part named transformers, starting from the end.
- idx = len(module_parts) - 1
- while idx >= 0 and module_parts[idx] != "transformers":
- idx -= 1
- if idx < 0:
- raise ValueError(f"{module_file} is not a transformers module.")
-
- return ".".join(module_parts[idx:])
-
-
-SPECIAL_PATTERNS = {
- "_CHECKPOINT_FOR_DOC =": "checkpoint",
- "_CONFIG_FOR_DOC =": "config_class",
- "_TOKENIZER_FOR_DOC =": "tokenizer_class",
- "_IMAGE_PROCESSOR_FOR_DOC =": "image_processor_class",
- "_FEAT_EXTRACTOR_FOR_DOC =": "feature_extractor_class",
- "_PROCESSOR_FOR_DOC =": "processor_class",
-}
-
-
-_re_class_func = re.compile(r"^(?:class|def)\s+([^\s:\(]+)\s*(?:\(|\:)", flags=re.MULTILINE)
-
-
-def remove_attributes(obj, target_attr):
- """Remove `target_attr` in `obj`."""
- lines = obj.split(os.linesep)
-
- target_idx = None
- for idx, line in enumerate(lines):
- # search for assignment
- if line.lstrip().startswith(f"{target_attr} = "):
- target_idx = idx
- break
- # search for function/method definition
- elif line.lstrip().startswith(f"def {target_attr}("):
- target_idx = idx
- break
-
- # target not found
- if target_idx is None:
- return obj
-
- line = lines[target_idx]
- indent_level = find_indent(line)
- # forward pass to find the ending of the block (including empty lines)
- parsed = extract_block("\n".join(lines[target_idx:]), indent_level)
- num_lines = len(parsed.split("\n"))
- for idx in range(num_lines):
- lines[target_idx + idx] = None
-
- # backward pass to find comments or decorator
- for idx in range(target_idx - 1, -1, -1):
- line = lines[idx]
- if (line.lstrip().startswith("#") or line.lstrip().startswith("@")) and find_indent(line) == indent_level:
- lines[idx] = None
- else:
- break
-
- new_obj = os.linesep.join([x for x in lines if x is not None])
-
- return new_obj
-
-
-def duplicate_module(
- module_file: Union[str, os.PathLike],
- old_model_patterns: ModelPatterns,
- new_model_patterns: ModelPatterns,
- dest_file: Optional[str] = None,
- add_copied_from: bool = True,
- attrs_to_remove: Optional[list[str]] = None,
+def add_model_to_auto_mappings(
+ old_model_infos: ModelInfos,
+ new_lowercase_name: str,
+ new_model_paper_name: str,
+ filenames_to_add: list[tuple[str, bool]],
):
"""
- Create a new module from an existing one and adapting all function and classes names from old patterns to new ones.
+ Add a model to all the relevant mappings in the auto module.
Args:
- module_file (`str` or `os.PathLike`): Path to the module to duplicate.
- old_model_patterns (`ModelPatterns`): The patterns for the old model.
- new_model_patterns (`ModelPatterns`): The patterns for the new model.
- dest_file (`str` or `os.PathLike`, *optional*): Path to the new module.
- add_copied_from (`bool`, *optional*, defaults to `True`):
- Whether or not to add `# Copied from` statements in the duplicated module.
+ old_model_infos (`ModelInfos`):
+ The structure containing the class information of the old model.
+ new_lowercase_name (`str`):
+ The new lowercase model name.
+ new_model_paper_name (`str`):
+ The fully cased name (as in the official paper name) of the new model.
+ filenames_to_add (`list[tuple[str, bool]]`):
+ A list of tuples of all potential filenames to add for a new model, along with a boolean flag describing if we
+ should add this file or not. For example, [(`modeling_xxx.py`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]
"""
- if dest_file is None:
- dest_file = str(module_file).replace(
- old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
- )
-
- with open(module_file, "r", encoding="utf-8") as f:
- content = f.read()
-
- content = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content)
- objects = parse_module_content(content)
-
- # Loop and treat all objects
- new_objects = []
- for obj in objects:
- special_pattern = False
- for pattern, attr in SPECIAL_PATTERNS.items():
- if pattern in obj:
- obj = obj.replace(getattr(old_model_patterns, attr), getattr(new_model_patterns, attr))
- new_objects.append(obj)
- special_pattern = True
- break
-
- if special_pattern:
- continue
-
- # Regular classes functions
- old_obj = obj
- obj, replacement = replace_model_patterns(obj, old_model_patterns, new_model_patterns)
- has_copied_from = re.search(r"^#\s+Copied from", obj, flags=re.MULTILINE) is not None
- if add_copied_from and not has_copied_from and _re_class_func.search(obj) is not None and len(replacement) > 0:
- # Copied from statement must be added just before the class/function definition, which may not be the
- # first line because of decorators.
- module_name = get_module_from_file(module_file)
- old_object_name = _re_class_func.search(old_obj).groups()[0]
- obj = add_content_to_text(
- obj, f"# Copied from {module_name}.{old_object_name} with {replacement}", add_before=_re_class_func
- )
- # In all cases, we remove Copied from statement with indent on methods.
- obj = re.sub("\n[ ]+# Copied from [^\n]*\n", "\n", obj)
-
- new_objects.append(obj)
-
- content = "\n".join(new_objects)
- # Remove some attributes that we don't want to copy to the new file(s)
- if attrs_to_remove is not None:
- for attr in attrs_to_remove:
- content = remove_attributes(content, target_attr=attr)
-
- with open(dest_file, "w", encoding="utf-8") as f:
- f.write(content)
-
-
-def filter_framework_files(
- files: list[Union[str, os.PathLike]], frameworks: Optional[list[str]] = None
-) -> list[Union[str, os.PathLike]]:
- """
- Filter a list of files to only keep the ones corresponding to a list of frameworks.
-
- Args:
- files (`list[Union[str, os.PathLike]]`): The list of files to filter.
- frameworks (`list[str]`, *optional*): The list of allowed frameworks.
-
- Returns:
- `list[Union[str, os.PathLike]]`: The list of filtered files.
- """
- if frameworks is None:
- frameworks = get_default_frameworks()
-
- framework_to_file = {}
- others = []
- for f in files:
- parts = Path(f).name.split("_")
- if "modeling" not in parts:
- others.append(f)
- continue
- if "tf" in parts:
- framework_to_file["tf"] = f
- elif "flax" in parts:
- framework_to_file["flax"] = f
- else:
- framework_to_file["pt"] = f
-
- return [framework_to_file[f] for f in frameworks if f in framework_to_file] + others
-
-
-def get_model_files(model_type: str, frameworks: Optional[list[str]] = None) -> dict[str, Union[Path, list[Path]]]:
- """
- Retrieves all the files associated to a model.
-
- Args:
- model_type (`str`): A valid model type (like "bert" or "gpt2")
- frameworks (`list[str]`, *optional*):
- If passed, will only keep the model files corresponding to the passed frameworks.
-
- Returns:
- `dict[str, Union[Path, list[Path]]]`: A dictionary with the following keys:
- - **doc_file** -- The documentation file for the model.
- - **model_files** -- All the files in the model module.
- - **test_files** -- The test files for the model.
- """
- module_name = model_type_to_module_name(model_type)
-
- model_module = TRANSFORMERS_PATH / "models" / module_name
- model_files = list(model_module.glob("*.py"))
- model_files = filter_framework_files(model_files, frameworks=frameworks)
-
- doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{model_type}.md"
-
- # Basic pattern for test files
- test_files = [
- f"test_modeling_{module_name}.py",
- f"test_modeling_tf_{module_name}.py",
- f"test_modeling_flax_{module_name}.py",
- f"test_tokenization_{module_name}.py",
- f"test_image_processing_{module_name}.py",
- f"test_feature_extraction_{module_name}.py",
- f"test_processor_{module_name}.py",
+ new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_"))
+ old_lowercase_name = old_model_infos.lowercase_name
+ old_cased_name = old_model_infos.camelcase_name
+ filenames_to_add = [
+ (filename.replace(old_lowercase_name, "auto"), to_add) for filename, to_add in filenames_to_add[1:]
]
- test_files = filter_framework_files(test_files, frameworks=frameworks)
- # Add the test directory
- test_files = [REPO_PATH / "tests" / "models" / module_name / f for f in test_files]
- # Filter by existing files
- test_files = [f for f in test_files if f.exists()]
+    # The fast tokenizer/image processor are registered in the same auto mapping as their slow counterpart, so a
+    # `*_auto_fast.py` entry is folded into the entry that precedes it instead of being kept as a separate file
+    corrected_filenames_to_add = []
+    for file, to_add in filenames_to_add:
+        # The alternation is grouped so that the `_auto_fast.py` suffix is required for both prefixes: ungrouped, any
+        # name containing `tokenization` (e.g. `tokenization_auto.py`) matched and was merged away as well
+        if re.search(r"(?:tokenization|image_processing)_auto_fast\.py", file):
+            previous_file, previous_to_add = corrected_filenames_to_add[-1]
+            corrected_filenames_to_add[-1] = (previous_file, previous_to_add or to_add)
+        else:
+            corrected_filenames_to_add.append((file, to_add))
- return {"doc_file": doc_file, "model_files": model_files, "module_name": module_name, "test_files": test_files}
-
-
-_re_checkpoint_in_config = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
-
-
-def find_base_model_checkpoint(
- model_type: str, model_files: Optional[dict[str, Union[Path, list[Path]]]] = None
-) -> str:
- """
- Finds the model checkpoint used in the docstrings for a given model.
-
- Args:
- model_type (`str`): A valid model type (like "bert" or "gpt2")
- model_files (`dict[str, Union[Path, list[Path]]`, *optional*):
- The files associated to `model_type`. Can be passed to speed up the function, otherwise will be computed.
-
- Returns:
- `str`: The checkpoint used.
- """
- if model_files is None:
- model_files = get_model_files(model_type)
- module_files = model_files["model_files"]
- for fname in module_files:
- # After the @auto_docstring refactor, we expect the checkpoint to be in the configuration file's docstring
- if "configuration" not in str(fname):
- continue
-
- with open(fname, "r", encoding="utf-8") as f:
- content = f.read()
- if _re_checkpoint_in_config.search(content) is not None:
- checkpoint = _re_checkpoint_in_config.search(content).groups()[0]
- # Remove quotes
- checkpoint = checkpoint.replace('"', "")
- checkpoint = checkpoint.replace("'", "")
- return checkpoint
-
- # TODO: Find some kind of fallback if there is no _CHECKPOINT_FOR_DOC in any of the modeling file.
- return ""
-
-
-def get_default_frameworks():
- """
- Returns the list of frameworks (PyTorch, TensorFlow, Flax) that are installed in the environment.
- """
- frameworks = []
- if is_torch_available():
- frameworks.append("pt")
- if is_tf_available():
- frameworks.append("tf")
- if is_flax_available():
- frameworks.append("flax")
- return frameworks
-
-
-_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")
-
-
-def retrieve_model_classes(model_type: str, frameworks: Optional[list[str]] = None) -> dict[str, list[str]]:
- """
- Retrieve the model classes associated to a given model.
-
- Args:
- model_type (`str`): A valid model type (like "bert" or "gpt2")
- frameworks (`list[str]`, *optional*):
- The frameworks to look for. Will default to `["pt", "tf", "flax"]`, passing a smaller list will restrict
- the classes returned.
-
- Returns:
- `dict[str, list[str]]`: A dictionary with one key per framework and the list of model classes associated to
- that framework as values.
- """
- if frameworks is None:
- frameworks = get_default_frameworks()
-
- modules = {
- "pt": auto_module.modeling_auto if is_torch_available() else None,
- "tf": auto_module.modeling_tf_auto if is_tf_available() else None,
- "flax": auto_module.modeling_flax_auto if is_flax_available() else None,
- }
-
- model_classes = {}
- for framework in frameworks:
- new_model_classes = []
- if modules[framework] is None:
- raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
- model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
- for model_mapping_name in model_mappings:
- model_mapping = getattr(modules[framework], model_mapping_name)
- if model_type in model_mapping:
- new_model_classes.append(model_mapping[model_type])
-
- if len(new_model_classes) > 0:
- # Remove duplicates
- model_classes[framework] = list(set(new_model_classes))
-
- return model_classes
-
-
-def retrieve_info_for_model(model_type, frameworks: Optional[list[str]] = None):
- """
- Retrieves all the information from a given model_type.
-
- Args:
- model_type (`str`): A valid model type (like "bert" or "gpt2")
- frameworks (`list[str]`, *optional*):
- If passed, will only keep the info corresponding to the passed frameworks.
-
- Returns:
- `Dict`: A dictionary with the following keys:
- - **frameworks** (`list[str]`): The list of frameworks that back this model type.
- - **model_classes** (`dict[str, list[str]]`): The model classes implemented for that model type.
- - **model_files** (`dict[str, Union[Path, list[Path]]]`): The files associated with that model type.
- - **model_patterns** (`ModelPatterns`): The various patterns for the model.
- """
- if model_type not in auto_module.MODEL_NAMES_MAPPING:
- raise ValueError(f"{model_type} is not a valid model type.")
-
- model_name = auto_module.MODEL_NAMES_MAPPING[model_type]
- config_class = auto_module.configuration_auto.CONFIG_MAPPING_NAMES[model_type]
- if model_type in auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES:
- tokenizer_classes = auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES[model_type]
- tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1]
- else:
- tokenizer_class = None
- image_processor_classes = auto_module.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES.get(model_type, None)
- if isinstance(image_processor_classes, tuple):
- image_processor_class, image_processor_fast_class = image_processor_classes
- else:
- image_processor_class = image_processor_classes
- image_processor_fast_class = None
- feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)
- processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)
-
- model_files = get_model_files(model_type, frameworks=frameworks)
- model_camel_cased = config_class.replace("Config", "")
-
- available_frameworks = []
- for fname in model_files["model_files"]:
- if "modeling_tf" in str(fname):
- available_frameworks.append("tf")
- elif "modeling_flax" in str(fname):
- available_frameworks.append("flax")
- elif "modeling" in str(fname):
- available_frameworks.append("pt")
-
- if frameworks is None:
- frameworks = get_default_frameworks()
-
- frameworks = [f for f in frameworks if f in available_frameworks]
-
- model_classes = retrieve_model_classes(model_type, frameworks=frameworks)
-
- model_upper_cased = model_camel_cased.upper()
- model_patterns = ModelPatterns(
- model_name,
- checkpoint=find_base_model_checkpoint(model_type, model_files=model_files),
- model_type=model_type,
- model_camel_cased=model_camel_cased,
- model_lower_cased=model_files["module_name"],
- model_upper_cased=model_upper_cased,
- config_class=config_class,
- tokenizer_class=tokenizer_class,
- image_processor_class=image_processor_class,
- image_processor_fast_class=image_processor_fast_class,
- feature_extractor_class=feature_extractor_class,
- processor_class=processor_class,
+ # Add the config mappings directly as the handling for config is a bit different
+ add_content_to_file(
+ TRANSFORMERS_PATH / "models" / "auto" / "configuration_auto.py",
+ new_content=f' ("{new_lowercase_name}", "{new_cased_name}Config"),\n',
+ add_after="CONFIG_MAPPING_NAMES = OrderedDict[str, str](\n [\n # Add configs here\n",
+ )
+ add_content_to_file(
+ TRANSFORMERS_PATH / "models" / "auto" / "configuration_auto.py",
+ new_content=f' ("{new_lowercase_name}", "{new_model_paper_name}"),\n',
+ add_after="MODEL_NAMES_MAPPING = OrderedDict[str, str](\n [\n # Add full (and cased) model names here\n",
)
- return {
- "frameworks": frameworks,
- "model_classes": model_classes,
- "model_files": model_files,
- "model_patterns": model_patterns,
- }
-
-
-def clean_frameworks_in_init(
- init_file: Union[str, os.PathLike], frameworks: Optional[list[str]] = None, keep_processing: bool = True
-):
- """
- Removes all the import lines that don't belong to a given list of frameworks or concern tokenizers/feature
- extractors/image processors/processors in an init.
-
- Args:
- init_file (`str` or `os.PathLike`): The path to the init to treat.
- frameworks (`list[str]`, *optional*):
- If passed, this will remove all imports that are subject to a framework not in frameworks
- keep_processing (`bool`, *optional*, defaults to `True`):
- Whether or not to keep the preprocessing (tokenizer, feature extractor, image processor, processor) imports
- in the init.
- """
- if frameworks is None:
- frameworks = get_default_frameworks()
-
- names = {"pt": "torch"}
- to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks]
- if not keep_processing:
- to_remove.extend(["sentencepiece", "tokenizers", "vision"])
-
- if len(to_remove) == 0:
- # Nothing to do
- return
-
- remove_pattern = "|".join(to_remove)
- re_conditional_imports = re.compile(rf"^\s*if not is_({remove_pattern})_available\(\):\s*$")
- re_try = re.compile(r"\s*try:")
- re_else = re.compile(r"\s*else:")
- re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")
-
- with open(init_file, "r", encoding="utf-8") as f:
- content = f.read()
-
- lines = content.split("\n")
- new_lines = []
- idx = 0
- while idx < len(lines):
- # Conditional imports in try-except-else blocks
- if (re_conditional_imports.search(lines[idx]) is not None) and (re_try.search(lines[idx - 1]) is not None):
- # Remove the preceding `try:`
- new_lines.pop()
- idx += 1
- # Iterate until `else:`
- while is_empty_line(lines[idx]) or re_else.search(lines[idx]) is None:
- idx += 1
- idx += 1
- indent = find_indent(lines[idx])
- while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
- idx += 1
- # Remove the import from utils
- elif re_is_xxx_available.search(lines[idx]) is not None:
- line = lines[idx]
- for framework in to_remove:
- line = line.replace(f", is_{framework}_available", "")
- line = line.replace(f"is_{framework}_available, ", "")
- line = line.replace(f"is_{framework}_available,", "")
- line = line.replace(f"is_{framework}_available", "")
-
- if len(line.strip()) > 0:
- new_lines.append(line)
- idx += 1
- # Otherwise we keep the line, except if it's a tokenizer import and we don't want to keep it.
- elif keep_processing or (
- re.search(r'^\s*"(tokenization|processing|feature_extraction|image_processing)', lines[idx]) is None
- and re.search(r"^\s*from .(tokenization|processing|feature_extraction|image_processing)", lines[idx])
- is None
- ):
- new_lines.append(lines[idx])
- idx += 1
- else:
- idx += 1
-
- with open(init_file, "w", encoding="utf-8") as f:
- f.write("\n".join(new_lines))
-
-
-def add_model_to_main_init(
- old_model_patterns: ModelPatterns,
- new_model_patterns: ModelPatterns,
- frameworks: Optional[list[str]] = None,
- with_processing: bool = True,
-):
- """
- Add a model to the main init of Transformers.
-
- Args:
- old_model_patterns (`ModelPatterns`): The patterns for the old model.
- new_model_patterns (`ModelPatterns`): The patterns for the new model.
- frameworks (`list[str]`, *optional*):
- If specified, only the models implemented in those frameworks will be added.
- with_processing (`bool`, *optional*, defaults to `True`):
- Whether the tokenizer/feature extractor/processor of the model should also be added to the init or not.
- """
- with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f:
- content = f.read()
-
- lines = content.split("\n")
- idx = 0
- new_lines = []
- framework = None
- while idx < len(lines):
- new_framework = False
- if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
- framework = None
- elif lines[idx].lstrip().startswith("if not is_torch_available"):
- framework = "pt"
- new_framework = True
- elif lines[idx].lstrip().startswith("if not is_tf_available"):
- framework = "tf"
- new_framework = True
- elif lines[idx].lstrip().startswith("if not is_flax_available"):
- framework = "flax"
- new_framework = True
-
- if new_framework:
- # For a new framework, we need to skip until the else: block to get where the imports are.
- while lines[idx].strip() != "else:":
- new_lines.append(lines[idx])
- idx += 1
-
- # Skip if we are in a framework not wanted.
- if framework is not None and frameworks is not None and framework not in frameworks:
- new_lines.append(lines[idx])
- idx += 1
- elif re.search(rf'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
- block = [lines[idx]]
- indent = find_indent(lines[idx])
- idx += 1
- while find_indent(lines[idx]) > indent:
- block.append(lines[idx])
- idx += 1
- if lines[idx].strip() in [")", "]", "],"]:
- block.append(lines[idx])
- idx += 1
- block = "\n".join(block)
- new_lines.append(block)
-
- add_block = True
- if not with_processing:
- processing_classes = [
- old_model_patterns.tokenizer_class,
- old_model_patterns.image_processor_class,
- old_model_patterns.image_processor_fast_class,
- old_model_patterns.feature_extractor_class,
- old_model_patterns.processor_class,
- ]
- # Only keep the ones that are not None
- processing_classes = [c for c in processing_classes if c is not None]
- for processing_class in processing_classes:
- block = block.replace(f' "{processing_class}",', "")
- block = block.replace(f', "{processing_class}"', "")
- block = block.replace(f" {processing_class},", "")
- block = block.replace(f", {processing_class}", "")
-
- if processing_class in block:
- add_block = False
- if add_block:
- new_lines.append(replace_model_patterns(block, old_model_patterns, new_model_patterns)[0])
- else:
- new_lines.append(lines[idx])
- idx += 1
-
- with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f:
- f.write("\n".join(new_lines))
-
-
-def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns):
- """
- Add a tokenizer to the relevant mappings in the auto module.
-
- Args:
- old_model_patterns (`ModelPatterns`): The patterns for the old model.
- new_model_patterns (`ModelPatterns`): The patterns for the new model.
- """
- if old_model_patterns.tokenizer_class is None or new_model_patterns.tokenizer_class is None:
- return
-
- with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
- content = f.read()
-
- pattern_tokenizer = re.compile(r"^\s*TOKENIZER_MAPPING_NAMES\s*=\s*OrderedDict\b")
- lines = content.split("\n")
- idx = 0
- # First we get to the TOKENIZER_MAPPING_NAMES block.
- while not pattern_tokenizer.search(lines[idx]):
- idx += 1
- idx += 1
-
- # That block will end at this prompt:
- while not lines[idx].startswith("TOKENIZER_MAPPING = _LazyAutoMapping"):
- # Either all the tokenizer block is defined on one line, in which case, it ends with "),"
- if lines[idx].endswith(","):
- block = lines[idx]
- # Otherwise it takes several lines until we get to a "),"
- else:
- block = []
- # should change to " )," instead of " ),"
- while not lines[idx].startswith(" ),"):
- block.append(lines[idx])
- idx += 1
- # if the lines[idx] does start with " )," we still need it in our block
- block.append(lines[idx])
- block = "\n".join(block)
- idx += 1
-
- # If we find the model type and tokenizer class in that block, we have the old model tokenizer block
- if f'"{old_model_patterns.model_type}"' in block and old_model_patterns.tokenizer_class in block:
- break
-
- new_block = block.replace(old_model_patterns.model_type, new_model_patterns.model_type)
- new_block = new_block.replace(old_model_patterns.tokenizer_class, new_model_patterns.tokenizer_class)
-
- new_lines = lines[:idx] + [new_block] + lines[idx:]
- with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "w", encoding="utf-8") as f:
- f.write("\n".join(new_lines))
-
-
-AUTO_CLASSES_PATTERNS = {
- "configuration_auto.py": [
- ' ("{model_type}", "{model_name}"),',
- ' ("{model_type}", "{config_class}"),',
- ' ("{model_type}", "{pretrained_archive_map}"),',
- ],
- "feature_extraction_auto.py": [' ("{model_type}", "{feature_extractor_class}"),'],
- "image_processing_auto.py": [' ("{model_type}", "{image_processor_classes}"),'],
- "modeling_auto.py": [' ("{model_type}", "{any_pt_class}"),'],
- "modeling_tf_auto.py": [' ("{model_type}", "{any_tf_class}"),'],
- "modeling_flax_auto.py": [' ("{model_type}", "{any_flax_class}"),'],
- "processing_auto.py": [' ("{model_type}", "{processor_class}"),'],
-}
-
-
-def add_model_to_auto_classes(
- old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns, model_classes: dict[str, list[str]]
-):
- """
- Add a model to the relevant mappings in the auto module.
-
- Args:
- old_model_patterns (`ModelPatterns`): The patterns for the old model.
- new_model_patterns (`ModelPatterns`): The patterns for the new model.
- model_classes (`dict[str, list[str]]`): A dictionary framework to list of model classes implemented.
- """
- for filename, patterns in AUTO_CLASSES_PATTERNS.items():
- # Extend patterns with all model classes if necessary
- new_patterns = []
- for pattern in patterns:
- if re.search("any_([a-z]*)_class", pattern) is not None:
- framework = re.search("any_([a-z]*)_class", pattern).groups()[0]
- if framework in model_classes:
- new_patterns.extend(
- [
- pattern.replace("{" + f"any_{framework}_class" + "}", cls)
- for cls in model_classes[framework]
- ]
- )
- elif "{config_class}" in pattern:
- new_patterns.append(pattern.replace("{config_class}", old_model_patterns.config_class))
- elif "{image_processor_classes}" in pattern:
- if (
- old_model_patterns.image_processor_class is not None
- and new_model_patterns.image_processor_class is not None
- ):
- if (
- old_model_patterns.image_processor_fast_class is not None
- and new_model_patterns.image_processor_fast_class is not None
- ):
- new_patterns.append(
- pattern.replace(
- '"{image_processor_classes}"',
- f'("{old_model_patterns.image_processor_class}", "{old_model_patterns.image_processor_fast_class}")',
- )
- )
- else:
- new_patterns.append(
- pattern.replace(
- '"{image_processor_classes}"', f'("{old_model_patterns.image_processor_class}",)'
- )
- )
- elif "{feature_extractor_class}" in pattern:
- if (
- old_model_patterns.feature_extractor_class is not None
- and new_model_patterns.feature_extractor_class is not None
- ):
- new_patterns.append(
- pattern.replace("{feature_extractor_class}", old_model_patterns.feature_extractor_class)
- )
- elif "{processor_class}" in pattern:
- if old_model_patterns.processor_class is not None and new_model_patterns.processor_class is not None:
- new_patterns.append(pattern.replace("{processor_class}", old_model_patterns.processor_class))
- else:
- new_patterns.append(pattern)
-
- # Loop through all patterns.
- for pattern in new_patterns:
- full_name = TRANSFORMERS_PATH / "models" / "auto" / filename
- old_model_line = pattern
- new_model_line = pattern
- for attr in ["model_type", "model_name"]:
- old_model_line = old_model_line.replace("{" + attr + "}", getattr(old_model_patterns, attr))
- new_model_line = new_model_line.replace("{" + attr + "}", getattr(new_model_patterns, attr))
- new_model_line = new_model_line.replace(
- old_model_patterns.model_camel_cased, new_model_patterns.model_camel_cased
+ for filename, to_add in corrected_filenames_to_add:
+ if to_add:
+ # The auto mapping
+ filename = filename.replace("_fast.py", ".py")
+ with open(TRANSFORMERS_PATH / "models" / "auto" / filename) as f:
+ file = f.read()
+ # The regex has to be a bit complex like this as the tokenizer mapping has new lines everywhere
+ matching_lines = re.findall(
+ rf'( {{8,12}}\(\s*"{old_lowercase_name}",.*?\),\n)(?: {{4,12}}\(|\])', file, re.DOTALL
)
- add_content_to_file(full_name, new_model_line, add_after=old_model_line)
-
- # Tokenizers require special handling
- insert_tokenizer_in_auto_module(old_model_patterns, new_model_patterns)
+ for match in matching_lines:
+ add_content_to_file(
+ TRANSFORMERS_PATH / "models" / "auto" / filename,
+ new_content=match.replace(old_lowercase_name, new_lowercase_name).replace(
+ old_cased_name, new_cased_name
+ ),
+ add_after=match,
+ )
-DOC_OVERVIEW_TEMPLATE = """## Overview
-
-The {model_name} model was proposed in []() by .
-
-
-The abstract from the paper is the following:
-
-**
-
-Tips:
-
-
-
-This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/).
-The original code can be found [here]().
-
-"""
-
-
-def duplicate_doc_file(
- doc_file: Union[str, os.PathLike],
- old_model_patterns: ModelPatterns,
- new_model_patterns: ModelPatterns,
- dest_file: Optional[Union[str, os.PathLike]] = None,
- frameworks: Optional[list[str]] = None,
-):
+def create_doc_file(new_paper_name: str, public_classes: list[str]):
"""
- Duplicate a documentation file and adapts it for a new model.
+ Create a new doc file to fill for the new model.
Args:
- module_file (`str` or `os.PathLike`): Path to the doc file to duplicate.
- old_model_patterns (`ModelPatterns`): The patterns for the old model.
- new_model_patterns (`ModelPatterns`): The patterns for the new model.
- dest_file (`str` or `os.PathLike`, *optional*): Path to the new doc file.
- Will default to the a file named `{new_model_patterns.model_type}.md` in the same folder as `module_file`.
- frameworks (`list[str]`, *optional*):
- If passed, will only keep the model classes corresponding to this list of frameworks in the new doc file.
+ new_paper_name (`str`):
+ The fully cased name (as in the official paper name) of the new model.
+ public_classes (`list[str]`):
+ A list of all the public classes that the model will have in the library.
"""
- with open(doc_file, "r", encoding="utf-8") as f:
- content = f.read()
+ added_note = (
+ "\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that "
+ "may not be rendered properly in your Markdown viewer.\n\n-->\n\n"
+ )
+ copyright_for_markdown = re.sub(r"# ?", "", COPYRIGHT).replace("coding=utf-8\n", "
+
+
+ # MyTest
+
+ ## Overview
+
+ The MyTest model was proposed in []() by .
+
+
+ The abstract from the paper is the following:
+
+
+
+ Tips:
+
+
+
+ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/).
+ The original code can be found [here]().
+
+ ## Usage examples
+
+
+
+ ## MyTestConfig
+
+ [[autodoc]] MyTestConfig
+
+ ## MyTestForCausalLM
+
+ [[autodoc]] MyTestForCausalLM
+
+ ## MyTestModel
+
+ [[autodoc]] MyTestModel
+ - forward
+
+ ## MyTestPreTrainedModel
+
+ [[autodoc]] MyTestPreTrainedModel
+ - forward
+
+ ## MyTestForSequenceClassification
+
+ [[autodoc]] MyTestForSequenceClassification
+
+ ## MyTestForQuestionAnswering
+
+ [[autodoc]] MyTestForQuestionAnswering
+
+ ## MyTestForTokenClassification
+
+ [[autodoc]] MyTestForTokenClassification
+ """
+ )
+ self.assertFileIsEqual(EXPECTED_DOC, os.path.join(self.DOC_PATH, "model_doc", "my_test.md"))
+
+ def test_phi4_with_all_processors(self):
+ # This is the structure without adding the tokenizers
+ filenames_to_add = (
+ ("configuration_phi4_multimodal.py", True),
+ ("modeling_phi4_multimodal.py", True),
+ ("tokenization_phi4_multimodal.py", False),
+ ("tokenization_phi4_multimodal_fast.py", False),
+ ("image_processing_phi4_multimodal.py", False),
+ ("image_processing_phi4_multimodal_fast.py", True),
+ ("video_processing_phi4_multimodal.py", False),
+ ("feature_extraction_phi4_multimodal.py", True),
+ ("processing_phi4_multimodal.py", True),
+ )
+ # Run the command
+ create_new_model_like(
+ old_model_infos=ModelInfos("phi4_multimodal"),
+ new_lowercase_name="my_test2",
+ new_model_paper_name="MyTest2",
+ filenames_to_add=filenames_to_add,
+ create_fast_image_processor=False,
+ )
+
+ # First assert that all files were created correctly
+ model_repo = os.path.join(self.MODEL_PATH, "my_test2")
+ tests_repo = os.path.join(self.TESTS_MODEL_PATH, "my_test2")
+ self.assertTrue(os.path.isfile(os.path.join(model_repo, "modular_my_test2.py")))
+ self.assertTrue(os.path.isfile(os.path.join(model_repo, "modeling_my_test2.py")))
+ self.assertTrue(os.path.isfile(os.path.join(model_repo, "configuration_my_test2.py")))
+ self.assertTrue(os.path.isfile(os.path.join(model_repo, "image_processing_my_test2_fast.py")))
+ self.assertTrue(os.path.isfile(os.path.join(model_repo, "feature_extraction_my_test2.py")))
+ self.assertTrue(os.path.isfile(os.path.join(model_repo, "processing_my_test2.py")))
+ self.assertTrue(os.path.isfile(os.path.join(model_repo, "__init__.py")))
+ self.assertTrue(os.path.isfile(os.path.join(self.DOC_PATH, "model_doc", "my_test2.md")))
+ self.assertTrue(os.path.isfile(os.path.join(tests_repo, "__init__.py")))
+ self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_modeling_my_test2.py")))
+ self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_feature_extraction_my_test2.py")))
+ self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_image_processing_my_test2.py")))
+
+ # Now assert the correct imports/auto mappings/toctree were added
+ self.assertInFile(
+ "from .my_test2 import *\n",
+ os.path.join(self.MODEL_PATH, "__init__.py"),
+ )
+ self.assertInFile(
+ '("my_test2", "MyTest2Config"),\n',
+ os.path.join(self.MODEL_PATH, "auto", "configuration_auto.py"),
+ )
+ self.assertInFile(
+ '("my_test2", "MyTest2"),\n',
+ os.path.join(self.MODEL_PATH, "auto", "configuration_auto.py"),
+ )
+ self.assertInFile(
+ '("my_test2", "MyTest2Model"),\n',
+ os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
+ )
+ self.assertInFile(
+ '("my_test2", "MyTest2ForCausalLM"),\n',
+ os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
+ )
+ self.assertInFile(
+ '("my_test2", (None, "MyTest2ImageProcessorFast")),\n',
+ os.path.join(self.MODEL_PATH, "auto", "image_processing_auto.py"),
+ )
+ self.assertInFile(
+ '("my_test2", "MyTest2FeatureExtractor"),\n',
+ os.path.join(self.MODEL_PATH, "auto", "feature_extraction_auto.py"),
+ )
+ self.assertInFile(
+ '("my_test2", "MyTest2Processor"),\n',
+ os.path.join(self.MODEL_PATH, "auto", "processing_auto.py"),
+ )
+ self.assertInFile(
+ "- local: model_doc/my_test2\n title: MyTest2\n",
+ os.path.join(self.DOC_PATH, "_toctree.yml"),
+ )
+
+ # Check some exact file creation. For model definition, only check modular as modeling/config/etc... are created
+ # directly from it
+ EXPECTED_MODULAR = textwrap.dedent(
+ f"""
+ # coding=utf-8
+ # Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..phi4_multimodal.configuration_phi4_multimodal import (
+ Phi4MultimodalAudioConfig,
+ Phi4MultimodalConfig,
+ Phi4MultimodalVisionConfig,
)
- self.init_file(doc_file, test_doc)
- duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
- self.check_result(new_doc_file, test_new_doc_pt_only)
-
- test_new_doc_no_tok = test_new_doc.replace(
- """
-## GPTNewNewTokenizer
-
-[[autodoc]] GPTNewNewTokenizer
- - save_vocabulary
-
-## GPTNewNewTokenizerFast
-
-[[autodoc]] GPTNewNewTokenizerFast
-""",
- "",
+ from ..phi4_multimodal.feature_extraction_phi4_multimodal import Phi4MultimodalFeatureExtractor
+ from ..phi4_multimodal.image_processing_phi4_multimodal_fast import (
+ Phi4MultimodalFastImageProcessorKwargs,
+ Phi4MultimodalImageProcessorFast,
)
- new_model_patterns = ModelPatterns(
- "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPT2Tokenizer"
+ from ..phi4_multimodal.modeling_phi4_multimodal import (
+ Phi4MultimodalAttention,
+ Phi4MultimodalAudioAttention,
+ Phi4MultimodalAudioConformerEncoderLayer,
+ Phi4MultimodalAudioConvModule,
+ Phi4MultimodalAudioDepthWiseSeperableConv1d,
+ Phi4MultimodalAudioEmbedding,
+ Phi4MultimodalAudioGluPointWiseConv,
+ Phi4MultimodalAudioMeanVarianceNormLayer,
+ Phi4MultimodalAudioMLP,
+ Phi4MultimodalAudioModel,
+ Phi4MultimodalAudioNemoConvSubsampling,
+ Phi4MultimodalAudioPreTrainedModel,
+ Phi4MultimodalAudioRelativeAttentionBias,
+ Phi4MultimodalDecoderLayer,
+ Phi4MultimodalFeatureEmbedding,
+ Phi4MultimodalForCausalLM,
+ Phi4MultimodalImageEmbedding,
+ Phi4MultimodalMLP,
+ Phi4MultimodalModel,
+ Phi4MultimodalPreTrainedModel,
+ Phi4MultimodalRMSNorm,
+ Phi4MultimodalRotaryEmbedding,
+ Phi4MultimodalVisionAttention,
+ Phi4MultimodalVisionEmbeddings,
+ Phi4MultimodalVisionEncoder,
+ Phi4MultimodalVisionEncoderLayer,
+ Phi4MultimodalVisionMLP,
+ Phi4MultimodalVisionModel,
+ Phi4MultimodalVisionMultiheadAttentionPoolingHead,
+ Phi4MultimodalVisionPreTrainedModel,
)
- self.init_file(doc_file, test_doc)
- duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt", "tf", "flax"])
- print(test_new_doc_no_tok)
- self.check_result(new_doc_file, test_new_doc_no_tok)
+ from ..phi4_multimodal.processing_phi4_multimodal import Phi4MultimodalProcessor, Phi4MultimodalProcessorKwargs
- test_new_doc_pt_only_no_tok = test_new_doc_no_tok.replace(
- """
-## TFGPTNewNewModel
-[[autodoc]] TFGPTNewNewModel
- - call
+ class MyTest2VisionConfig(Phi4MultimodalVisionConfig):
+ pass
-## FlaxGPTNewNewModel
-[[autodoc]] FlaxGPTNewNewModel
- - __call__
+ class MyTest2AudioConfig(Phi4MultimodalAudioConfig):
+ pass
-""",
- "",
- )
- self.init_file(doc_file, test_doc)
- duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
- self.check_result(new_doc_file, test_new_doc_pt_only_no_tok)
+
+ class MyTest2Config(Phi4MultimodalConfig):
+ pass
+
+
+ class MyTest2VisionMLP(Phi4MultimodalVisionMLP):
+ pass
+
+
+ class MyTest2VisionAttention(Phi4MultimodalVisionAttention):
+ pass
+
+
+ class MyTest2VisionEncoderLayer(Phi4MultimodalVisionEncoderLayer):
+ pass
+
+
+ class MyTest2VisionEncoder(Phi4MultimodalVisionEncoder):
+ pass
+
+
+ class MyTest2VisionPreTrainedModel(Phi4MultimodalVisionPreTrainedModel):
+ pass
+
+
+ class MyTest2VisionEmbeddings(Phi4MultimodalVisionEmbeddings):
+ pass
+
+
+ class MyTest2VisionMultiheadAttentionPoolingHead(Phi4MultimodalVisionMultiheadAttentionPoolingHead):
+ pass
+
+
+ class MyTest2VisionModel(Phi4MultimodalVisionModel):
+ pass
+
+
+ class MyTest2ImageEmbedding(Phi4MultimodalImageEmbedding):
+ pass
+
+
+ class MyTest2AudioMLP(Phi4MultimodalAudioMLP):
+ pass
+
+
+ class MyTest2AudioAttention(Phi4MultimodalAudioAttention):
+ pass
+
+
+ class MyTest2AudioDepthWiseSeperableConv1d(Phi4MultimodalAudioDepthWiseSeperableConv1d):
+ pass
+
+
+ class MyTest2AudioGluPointWiseConv(Phi4MultimodalAudioGluPointWiseConv):
+ pass
+
+
+ class MyTest2AudioConvModule(Phi4MultimodalAudioConvModule):
+ pass
+
+
+ class MyTest2AudioConformerEncoderLayer(Phi4MultimodalAudioConformerEncoderLayer):
+ pass
+
+
+ class MyTest2AudioNemoConvSubsampling(Phi4MultimodalAudioNemoConvSubsampling):
+ pass
+
+
+ class MyTest2AudioRelativeAttentionBias(Phi4MultimodalAudioRelativeAttentionBias):
+ pass
+
+
+ class MyTest2AudioMeanVarianceNormLayer(Phi4MultimodalAudioMeanVarianceNormLayer):
+ pass
+
+
+ class MyTest2AudioPreTrainedModel(Phi4MultimodalAudioPreTrainedModel):
+ pass
+
+
+ class MyTest2AudioModel(Phi4MultimodalAudioModel):
+ pass
+
+
+ class MyTest2AudioEmbedding(Phi4MultimodalAudioEmbedding):
+ pass
+
+
+ class MyTest2RMSNorm(Phi4MultimodalRMSNorm):
+ pass
+
+
+ class MyTest2MLP(Phi4MultimodalMLP):
+ pass
+
+
+ class MyTest2Attention(Phi4MultimodalAttention):
+ pass
+
+
+ class MyTest2DecoderLayer(Phi4MultimodalDecoderLayer):
+ pass
+
+
+ class MyTest2FeatureEmbedding(Phi4MultimodalFeatureEmbedding):
+ pass
+
+
+ class MyTest2RotaryEmbedding(Phi4MultimodalRotaryEmbedding):
+ pass
+
+
+ class MyTest2PreTrainedModel(Phi4MultimodalPreTrainedModel):
+ pass
+
+
+ class MyTest2Model(Phi4MultimodalModel):
+ pass
+
+
+ class MyTest2ForCausalLM(Phi4MultimodalForCausalLM):
+ pass
+
+
+ class MyTest2FastImageProcessorKwargs(Phi4MultimodalFastImageProcessorKwargs):
+ pass
+
+
+ class MyTest2ImageProcessorFast(Phi4MultimodalImageProcessorFast):
+ pass
+
+
+ class MyTest2FeatureExtractor(Phi4MultimodalFeatureExtractor):
+ pass
+
+
+ class MyTest2ProcessorKwargs(Phi4MultimodalProcessorKwargs):
+ pass
+
+
+ class MyTest2Processor(Phi4MultimodalProcessor):
+ pass
+
+
+ __all__ = [
+ "MyTest2VisionConfig",
+ "MyTest2AudioConfig",
+ "MyTest2Config",
+ "MyTest2AudioPreTrainedModel",
+ "MyTest2AudioModel",
+ "MyTest2VisionPreTrainedModel",
+ "MyTest2VisionModel",
+ "MyTest2PreTrainedModel",
+ "MyTest2Model",
+ "MyTest2ForCausalLM",
+ "MyTest2ImageProcessorFast",
+ "MyTest2FeatureExtractor",
+ "MyTest2Processor",
+ ]
+ """
+ )
+ self.assertFileIsEqual(EXPECTED_MODULAR, os.path.join(model_repo, "modular_my_test2.py"))
+
+ EXPECTED_INIT = textwrap.dedent(
+ f"""
+ # coding=utf-8
+ # Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+ from .configuration_my_test2 import *
+ from .feature_extraction_my_test2 import *
+ from .image_processing_my_test2_fast import *
+ from .modeling_my_test2 import *
+ from .processing_my_test2 import *
+ else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ """
+ )
+ self.assertFileIsEqual(EXPECTED_INIT, os.path.join(model_repo, "__init__.py"))
+
+ EXPECTED_DOC = textwrap.dedent(
+ f"""
+
+
+
+ # MyTest2
+
+ ## Overview
+
+ The MyTest2 model was proposed in []() by .
+
+
+ The abstract from the paper is the following:
+
+
+
+ Tips:
+
+
+
+ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/).
+ The original code can be found [here]().
+
+ ## Usage examples
+
+
+
+ ## MyTest2VisionConfig
+
+ [[autodoc]] MyTest2VisionConfig
+
+ ## MyTest2AudioConfig
+
+ [[autodoc]] MyTest2AudioConfig
+
+ ## MyTest2Config
+
+ [[autodoc]] MyTest2Config
+
+ ## MyTest2AudioPreTrainedModel
+
+ [[autodoc]] MyTest2AudioPreTrainedModel
+ - forward
+
+ ## MyTest2AudioModel
+
+ [[autodoc]] MyTest2AudioModel
+ - forward
+
+ ## MyTest2VisionPreTrainedModel
+
+ [[autodoc]] MyTest2VisionPreTrainedModel
+ - forward
+
+ ## MyTest2VisionModel
+
+ [[autodoc]] MyTest2VisionModel
+ - forward
+
+ ## MyTest2PreTrainedModel
+
+ [[autodoc]] MyTest2PreTrainedModel
+ - forward
+
+ ## MyTest2Model
+
+ [[autodoc]] MyTest2Model
+ - forward
+
+ ## MyTest2ForCausalLM
+
+ [[autodoc]] MyTest2ForCausalLM
+
+ ## MyTest2ImageProcessorFast
+
+ [[autodoc]] MyTest2ImageProcessorFast
+
+ ## MyTest2FeatureExtractor
+
+ [[autodoc]] MyTest2FeatureExtractor
+
+ ## MyTest2Processor
+
+ [[autodoc]] MyTest2Processor
+ """
+ )
+ self.assertFileIsEqual(EXPECTED_DOC, os.path.join(self.DOC_PATH, "model_doc", "my_test2.md"))
diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py
index ade3589868..fd553cc3b9 100644
--- a/utils/modular_model_converter.py
+++ b/utils/modular_model_converter.py
@@ -17,12 +17,12 @@ import glob
import importlib
import os
import re
+import subprocess
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from typing import Optional, Union
import libcst as cst
-from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from libcst import ClassDef, CSTVisitor
from libcst import matchers as m
@@ -1676,6 +1676,16 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
return files
+def run_ruff(code, check=False):  # Pipe `code` (str) through the `ruff` CLI and return ruff's stdout as str.
+ if check:  # check=True -> `ruff check` with --fix; otherwise -> `ruff format`
+ command = ["ruff", "check", "-", "--fix", "--exit-zero"]  # "-": source comes from stdin (see stdin=PIPE below)
+ else:
+ command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]  # config path is relative to cwd
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
+ stdout, _ = process.communicate(input=code.encode())  # stderr is captured but discarded; exit code is not inspected
+ return stdout.decode()  # NOTE(review): presumably empty if ruff itself fails, since errors go to stderr - confirm
+
+
def convert_modular_file(modular_file):
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
output = {}