More utils doc (#25457)
* Document and clean more utils. * More documentation and fixes * Switch to Lysandre's token * Address review comments * Actually put else
This commit is contained in:
@@ -12,11 +12,30 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that checks the big table in the file docs/source/en/index.md and potentially updates it.
|
||||
|
||||
Use from the root of the repo with:
|
||||
|
||||
```bash
|
||||
python utils/check_table.py
|
||||
```
|
||||
|
||||
for a check that will error in case of inconsistencies (used by `make repo-consistency`).
|
||||
|
||||
To auto-fix issues run:
|
||||
|
||||
```bash
|
||||
python utils/check_table.py --fix_and_overwrite
|
||||
```
|
||||
|
||||
which is used by `make fix-copies`.
|
||||
"""
|
||||
import argparse
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
@@ -28,19 +47,28 @@ PATH_TO_DOCS = "docs/source/en"
|
||||
REPO_PATH = "."
|
||||
|
||||
|
||||
def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> str:
|
||||
"""
|
||||
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
|
||||
lines.
|
||||
Find the text in filename between two prompts.
|
||||
|
||||
Args:
|
||||
filename (`str`): The file to search into.
|
||||
start_prompt (`str`): A string to look for at the start of the content searched.
|
||||
end_prompt (`str`): A string that will mark the end of the content to look for.
|
||||
|
||||
Returns:
|
||||
`str`: The content between the prompts.
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Find the start prompt.
|
||||
start_index = 0
|
||||
while not lines[start_index].startswith(start_prompt):
|
||||
start_index += 1
|
||||
start_index += 1
|
||||
|
||||
# Now go until the end prompt.
|
||||
end_index = start_index
|
||||
while not lines[end_index].startswith(end_prompt):
|
||||
end_index += 1
|
||||
@@ -54,12 +82,10 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
||||
|
||||
|
||||
# Suffixes that are used to identify models, separated by |. Add new suffixes here: the three regexes below are all
# built from this single alternation, so they cannot drift apart.
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"

# Regexes that match TF/Flax/PT model names. The captured group is whatever sits between the framework prefix (if
# any) and the suffix, e.g. "Bert" for "TFBertModel".
_re_tf_models = re.compile(rf"TF(.*)(?:{ALLOWED_MODEL_SUFFIXES})")
_re_flax_models = re.compile(rf"Flax(.*)(?:{ALLOWED_MODEL_SUFFIXES})")
# Will match any TF or Flax model too so need to be in an else branch after the two previous regexes.
_re_pt_models = re.compile(rf"(.*)(?:{ALLOWED_MODEL_SUFFIXES})")
|
||||
|
||||
|
||||
@@ -67,22 +93,49 @@ _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGenerati
|
||||
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
||||
|
||||
|
||||
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||
def camel_case_split(identifier):
|
||||
"Split a camelcased `identifier` into words."
|
||||
def camel_case_split(identifier: str) -> List[str]:
    """
    Split a camel-cased name into words.

    Args:
        identifier (`str`): The camel-cased name to parse.

    Returns:
        `List[str]`: The list of words in the identifier (as separated by capital letters). An empty identifier
        gives an empty list.

    Example:

    ```py
    >>> camel_case_split("CamelCasedClass")
    ['Camel', 'Cased', 'Class']
    ```
    """
    # Regex thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
    # A word ends at a lowercase->uppercase boundary, just before the last capital of an uppercase run that is
    # followed by a lowercase letter (so "TFBert" gives "TF" + "Bert"), or at the end of the string.
    matches = re.finditer(r".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
    return [m.group(0) for m in matches]
|
||||
|
||||
|
||||
def _center_text(text, width):
|
||||
def _center_text(text: str, width: int) -> str:
    """
    Pad a text with spaces on both sides so that it sits in the middle of a cell of a given width.

    Args:
        text (`str`): The text to center.
        width (`int`): The desired length of the result.

    Returns:
        `str`: The original `text` surrounded by spaces. When the padding cannot be split evenly, the extra space
        goes on the right. The "✅" and "❌" marks are counted as two characters when computing the padding. No
        padding is added when `width` is smaller than the text.
    """
    # The check/cross marks are counted as two characters rather than their `len`.
    if text in ("✅", "❌"):
        visible_length = 2
    else:
        visible_length = len(text)

    total_padding = width - visible_length
    left_padding = total_padding // 2
    right_padding = total_padding - left_padding
    return "".join((" " * left_padding, text, " " * right_padding))
|
||||
|
||||
|
||||
def get_model_table_from_auto_modules():
|
||||
"""Generates an up-to-date model table from the content of the auto modules."""
|
||||
def get_model_table_from_auto_modules() -> str:
|
||||
"""
|
||||
Generates an up-to-date model table from the content of the auto modules.
|
||||
"""
|
||||
# Dictionary model names to config.
|
||||
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||
model_name_to_config = {
|
||||
@@ -92,7 +145,7 @@ def get_model_table_from_auto_modules():
|
||||
}
|
||||
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
|
||||
|
||||
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
|
||||
# Dictionaries flagging if each model prefix has a backend in PT/TF/Flax.
|
||||
pt_models = collections.defaultdict(bool)
|
||||
tf_models = collections.defaultdict(bool)
|
||||
flax_models = collections.defaultdict(bool)
|
||||
@@ -145,7 +198,13 @@ def get_model_table_from_auto_modules():
|
||||
|
||||
|
||||
def check_model_table(overwrite=False):
|
||||
"""Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`."""
|
||||
"""
|
||||
Check the model table in the index.md is consistent with the state of the lib and potentially fix it.
|
||||
|
||||
Args:
|
||||
overwrite (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to overwrite the table when it's not up to date.
|
||||
"""
|
||||
current_table, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "index.md"),
|
||||
start_prompt="<!--This table is updated automatically from the auto modules",
|
||||
|
||||
Reference in New Issue
Block a user