More utils doc (#25457)
* Document and clean more utils.
* More documentation and fixes
* Switch to Lysandre's token
* Address review comments
* Actually put else
2
.github/workflows/update_metdata.yml
vendored
@@ -24,4 +24,4 @@ jobs:
|
|||||||
|
|
||||||
- name: Update metadata
|
- name: Update metadata
|
||||||
run: |
|
run: |
|
||||||
python utils/update_metadata.py --token ${{ secrets.SYLVAIN_HF_TOKEN }} --commit_sha ${{ github.sha }}
|
python utils/update_metadata.py --token ${{ secrets.LYSANDRE_HF_TOKEN }} --commit_sha ${{ github.sha }}
|
||||||
|
|||||||
28
setup.py
@@ -17,25 +17,26 @@ Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/m
|
|||||||
|
|
||||||
To create the package for pypi.
|
To create the package for pypi.
|
||||||
|
|
||||||
1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
|
1. Create the release branch named: v<RELEASE>-release, for example v4.19-release. For a patch release, check out the
|
||||||
documentation.
|
current release branch.
|
||||||
|
|
||||||
If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
|
If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
|
||||||
for the post-release and run `make fix-copies` on the main branch as well.
|
for the post-release and run `make fix-copies` on the main branch as well.
|
||||||
|
|
||||||
2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.
|
2. Run `make pre-release` (or `make pre-patch` for a patch release) and commit these changes with the message:
|
||||||
|
"Release: <VERSION>" and push.
|
||||||
|
|
||||||
3. Unpin specific versions from setup.py that use a git install.
|
3. Go back to the main branch and run `make post-release` then `make fix-copies`. Commit these changes with the
|
||||||
|
message "v<NEXT_VERSION>.dev.0" and push to main.
|
||||||
|
|
||||||
4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
|
# If you were just cutting the branch in preparation for a release, you can stop here for now.
|
||||||
message: "Release: <VERSION>" and push.
|
|
||||||
|
|
||||||
5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
|
4. Wait for the tests on the release branch to be completed and be green (otherwise revert and fix bugs)
|
||||||
|
|
||||||
6. Add a tag in git to mark the release: "git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi' "
|
5. On the release branch, add a tag in git to mark the release: "git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi' "
|
||||||
Push the tag to git: git push --tags origin v<RELEASE>-release
|
Push the tag to git: git push --tags origin v<RELEASE>-release
|
||||||
|
|
||||||
7. Build both the sources and the wheel. Do not change anything in setup.py between
|
6. Build both the sources and the wheel. Do not change anything in setup.py between
|
||||||
creating the wheel and the source distribution (obviously).
|
creating the wheel and the source distribution (obviously).
|
||||||
|
|
||||||
Run `make build-release`. This will build the release and do some sanity checks for you. If this ends with an error
|
Run `make build-release`. This will build the release and do some sanity checks for you. If this ends with an error
|
||||||
@@ -43,7 +44,7 @@ To create the package for pypi.
|
|||||||
|
|
||||||
You should now have a /dist directory with both .whl and .tar.gz source versions.
|
You should now have a /dist directory with both .whl and .tar.gz source versions.
|
||||||
|
|
||||||
8. Check that everything looks correct by uploading the package to the pypi test server:
|
7. Check that everything looks correct by uploading the package to the pypi test server:
|
||||||
|
|
||||||
twine upload dist/* -r testpypi
|
twine upload dist/* -r testpypi
|
||||||
(pypi suggests using twine as other methods upload files via plaintext.)
|
(pypi suggests using twine as other methods upload files via plaintext.)
|
||||||
@@ -60,13 +61,10 @@ To create the package for pypi.
|
|||||||
|
|
||||||
If making a patch release, double check the bug you are patching is indeed resolved.
|
If making a patch release, double check the bug you are patching is indeed resolved.
|
||||||
|
|
||||||
9. Upload the final version to actual pypi:
|
8. Upload the final version to actual pypi:
|
||||||
twine upload dist/* -r pypi
|
twine upload dist/* -r pypi
|
||||||
|
|
||||||
10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
|
9. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
|
||||||
|
|
||||||
11. Run `make post-release` then run `make fix-copies`. If you were on a branch for the release,
|
|
||||||
you need to go back to main before executing this.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -12,11 +12,30 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Utility that checks the big table in the file docs/source/en/index.md and potentially updates it.
|
||||||
|
|
||||||
|
Use from the root of the repo with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/check_table.py
|
||||||
|
```
|
||||||
|
|
||||||
|
for a check that will error in case of inconsistencies (used by `make repo-consistency`).
|
||||||
|
|
||||||
|
To auto-fix issues run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/check_table.py --fix_and_overwrite
|
||||||
|
```
|
||||||
|
|
||||||
|
which is used by `make fix-copies`.
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import collections
|
import collections
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from transformers.utils import direct_transformers_import
|
from transformers.utils import direct_transformers_import
|
||||||
|
|
||||||
@@ -28,19 +47,28 @@ PATH_TO_DOCS = "docs/source/en"
|
|||||||
REPO_PATH = "."
|
REPO_PATH = "."
|
||||||
|
|
||||||
|
|
||||||
def _find_text_in_file(filename, start_prompt, end_prompt):
|
def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> str:
|
||||||
"""
|
"""
|
||||||
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
|
Find the text in filename between two prompts.
|
||||||
lines.
|
|
||||||
|
Args:
|
||||||
|
filename (`str`): The file to search into.
|
||||||
|
start_prompt (`str`): A string to look for at the start of the content searched.
|
||||||
|
end_prompt (`str`): A string that will mark the end of the content to look for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The content between the prompts.
|
||||||
"""
|
"""
|
||||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
|
||||||
# Find the start prompt.
|
# Find the start prompt.
|
||||||
start_index = 0
|
start_index = 0
|
||||||
while not lines[start_index].startswith(start_prompt):
|
while not lines[start_index].startswith(start_prompt):
|
||||||
start_index += 1
|
start_index += 1
|
||||||
start_index += 1
|
start_index += 1
|
||||||
|
|
||||||
|
# Now go until the end prompt.
|
||||||
end_index = start_index
|
end_index = start_index
|
||||||
while not lines[end_index].startswith(end_prompt):
|
while not lines[end_index].startswith(end_prompt):
|
||||||
end_index += 1
|
end_index += 1
|
||||||
@@ -54,9 +82,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
|
|||||||
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
||||||
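The prompt search in this hunk can be sketched on an in-memory list of lines. This is a minimal illustration, not the repository's function: `find_between_prompts` is a hypothetical name, it takes lines instead of a filename, and it omits whatever the elided part of the original does between the two loops and the `return`. Note that the function shown above returns four values (the text, both indices, and the lines), even though the new annotation reads `-> str`.

```python
from typing import List, Tuple


def find_between_prompts(lines: List[str], start_prompt: str, end_prompt: str) -> Tuple[str, int, int]:
    """Hypothetical helper: return the text between two prompt lines plus its line span."""
    # Find the line that begins with the start prompt, then step past it.
    start_index = 0
    while not lines[start_index].startswith(start_prompt):
        start_index += 1
    start_index += 1

    # Walk forward until the line that begins with the end prompt.
    end_index = start_index
    while not lines[end_index].startswith(end_prompt):
        end_index += 1

    return "".join(lines[start_index:end_index]), start_index, end_index


doc = ["intro\n", "<!--start-->\n", "row 1\n", "row 2\n", "<!--end-->\n", "outro\n"]
text, start, end = find_between_prompts(doc, "<!--start", "<!--end")
print(repr(text), start, end)
```

As in the diffed code, a missing prompt is not handled: the loops would run past the end of the list and raise an `IndexError`.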
|
|
||||||
|
|
||||||
# Add here suffixes that are used to identify models, separated by |
|
# Regexes that match TF/Flax/PT model names. Add here suffixes that are used to identify models, separated by |
|
||||||
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
|
|
||||||
# Regexes that match TF/Flax/PT model names.
|
|
||||||
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||||
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||||
# Will match any TF or Flax model too so need to be in an else branch after the two previous regexes.
|
# Will match any TF or Flax model too so need to be in an else branch after the two previous regexes.
|
||||||
@@ -67,22 +93,49 @@ _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGenerati
|
|||||||
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
||||||
|
|
||||||
|
|
||||||
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
def camel_case_split(identifier: str) -> List[str]:
|
||||||
def camel_case_split(identifier):
|
"""
|
||||||
"Split a camelcased `identifier` into words."
|
Split a camel-cased name into words.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier (`str`): The camel-cased name to parse.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of words in the identifier (as separated by capital letters).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> camel_case_split("CamelCasedClass")
|
||||||
|
["Camel", "Cased", "Class"]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# Regex thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||||
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
|
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
|
||||||
return [m.group(0) for m in matches]
|
return [m.group(0) for m in matches]
|
||||||
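For reference, the function above can be run standalone. The body is the one shown in the diff; the extra inputs are illustrative choices that exercise the second lookaround, which keeps a run of capitals such as `TF` together.

```python
import re
from typing import List


def camel_case_split(identifier: str) -> List[str]:
    """Split a camel-cased name into words (same regex as in the diff above)."""
    # Break after a lowercase letter followed by an uppercase one, or between an
    # uppercase run and a capitalized word, or at the end of the string.
    matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
    return [m.group(0) for m in matches]


print(camel_case_split("CamelCasedClass"))
print(camel_case_split("TFBertModel"))
```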
|
|
||||||
|
|
||||||
def _center_text(text, width):
|
def _center_text(text: str, width: int) -> str:
|
||||||
|
"""
|
||||||
|
Utility that will add spaces on the left and right of a text to make it centered for a given width.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (`str`): The text to center.
|
||||||
|
width (`int`): The desired length of the result.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: A text of length `width` with the original `text` in the middle.
|
||||||
|
"""
|
||||||
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
||||||
left_indent = (width - text_length) // 2
|
left_indent = (width - text_length) // 2
|
||||||
right_indent = width - text_length - left_indent
|
right_indent = width - text_length - left_indent
|
||||||
return " " * left_indent + text + " " * right_indent
|
return " " * left_indent + text + " " * right_indent
|
||||||
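A short standalone run shows the padding arithmetic. The body is the one from the diff, renamed `center_text` here only to make the sketch self-contained. The special case counts the two emoji as two columns wide, so the returned string has one character fewer than `width` in that case.

```python
def center_text(text: str, width: int) -> str:
    """Pad `text` with spaces on both sides so that it is centered in `width` columns."""
    # The check marks are treated as two columns wide.
    text_length = 2 if text == "✅" or text == "❌" else len(text)
    left_indent = (width - text_length) // 2
    right_indent = width - text_length - left_indent
    return " " * left_indent + text + " " * right_indent


print(repr(center_text("ab", 6)))
print(repr(center_text("abc", 6)))
print(repr(center_text("✅", 5)))
```

When the remaining space is odd, the extra space goes on the right.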
|
|
||||||
|
|
||||||
def get_model_table_from_auto_modules():
|
def get_model_table_from_auto_modules() -> str:
|
||||||
"""Generates an up-to-date model table from the content of the auto modules."""
|
"""
|
||||||
|
Generates an up-to-date model table from the content of the auto modules.
|
||||||
|
"""
|
||||||
# Dictionary model names to config.
|
# Dictionary model names to config.
|
||||||
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||||
model_name_to_config = {
|
model_name_to_config = {
|
||||||
@@ -92,7 +145,7 @@ def get_model_table_from_auto_modules():
|
|||||||
}
|
}
|
||||||
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
|
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
|
||||||
|
|
||||||
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
|
# Dictionaries flagging if each model prefix has a backend in PT/TF/Flax.
|
||||||
pt_models = collections.defaultdict(bool)
|
pt_models = collections.defaultdict(bool)
|
||||||
tf_models = collections.defaultdict(bool)
|
tf_models = collections.defaultdict(bool)
|
||||||
flax_models = collections.defaultdict(bool)
|
flax_models = collections.defaultdict(bool)
|
||||||
@@ -145,7 +198,13 @@ def get_model_table_from_auto_modules():
|
|||||||
|
|
||||||
|
|
||||||
def check_model_table(overwrite=False):
|
def check_model_table(overwrite=False):
|
||||||
"""Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`."""
|
"""
|
||||||
|
Check the model table in the index.md is consistent with the state of the lib and potentially fix it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
overwrite (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to overwrite the table when it's not up to date.
|
||||||
|
"""
|
||||||
current_table, start_index, end_index, lines = _find_text_in_file(
|
current_table, start_index, end_index, lines = _find_text_in_file(
|
||||||
filename=os.path.join(PATH_TO_DOCS, "index.md"),
|
filename=os.path.join(PATH_TO_DOCS, "index.md"),
|
||||||
start_prompt="<!--This table is updated automatically from the auto modules",
|
start_prompt="<!--This table is updated automatically from the auto modules",
|
||||||
|
|||||||
@@ -12,7 +12,26 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Utility that checks the list of models in the tips in the task-specific pages of the doc is up to date and potentially
|
||||||
|
fixes it.
|
||||||
|
|
||||||
|
Use from the root of the repo with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/check_task_guides.py
|
||||||
|
```
|
||||||
|
|
||||||
|
for a check that will error in case of inconsistencies (used by `make repo-consistency`).
|
||||||
|
|
||||||
|
To auto-fix issues run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/check_task_guides.py --fix_and_overwrite
|
||||||
|
```
|
||||||
|
|
||||||
|
which is used by `make fix-copies`.
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -25,10 +44,17 @@ TRANSFORMERS_PATH = "src/transformers"
|
|||||||
PATH_TO_TASK_GUIDES = "docs/source/en/tasks"
|
PATH_TO_TASK_GUIDES = "docs/source/en/tasks"
|
||||||
|
|
||||||
|
|
||||||
def _find_text_in_file(filename, start_prompt, end_prompt):
|
def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> str:
|
||||||
"""
|
"""
|
||||||
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
|
Find the text in filename between two prompts.
|
||||||
lines.
|
|
||||||
|
Args:
|
||||||
|
filename (`str`): The file to search into.
|
||||||
|
start_prompt (`str`): A string to look for at the start of the content searched.
|
||||||
|
end_prompt (`str`): A string that will mark the end of the content to look for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The content between the prompts.
|
||||||
"""
|
"""
|
||||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
@@ -38,6 +64,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
|
|||||||
start_index += 1
|
start_index += 1
|
||||||
start_index += 1
|
start_index += 1
|
||||||
|
|
||||||
|
# Now go until the end prompt.
|
||||||
end_index = start_index
|
end_index = start_index
|
||||||
while not lines[end_index].startswith(end_prompt):
|
while not lines[end_index].startswith(end_prompt):
|
||||||
end_index += 1
|
end_index += 1
|
||||||
@@ -54,6 +81,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
|
|||||||
# This is to make sure the transformers module imported is the one in the repo.
|
# This is to make sure the transformers module imported is the one in the repo.
|
||||||
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
||||||
|
|
||||||
|
# Map between a task guide and the corresponding auto class.
|
||||||
TASK_GUIDE_TO_MODELS = {
|
TASK_GUIDE_TO_MODELS = {
|
||||||
"asr.md": transformers_module.models.auto.modeling_auto.MODEL_FOR_CTC_MAPPING_NAMES,
|
"asr.md": transformers_module.models.auto.modeling_auto.MODEL_FOR_CTC_MAPPING_NAMES,
|
||||||
"audio_classification.md": transformers_module.models.auto.modeling_auto.MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
"audio_classification.md": transformers_module.models.auto.modeling_auto.MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||||
@@ -81,9 +109,15 @@ SPECIAL_TASK_GUIDE_TO_MODEL_TYPES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_model_list_for_task(task_guide):
|
def get_model_list_for_task(task_guide: str) -> str:
|
||||||
"""
|
"""
|
||||||
Return the list of models supporting given task.
|
Return the list of models supporting a given task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_guide (`str`): The name of the task guide to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The list of models supporting this task, as links to their respective doc pages separated by commas.
|
||||||
"""
|
"""
|
||||||
model_maping_names = TASK_GUIDE_TO_MODELS[task_guide]
|
model_maping_names = TASK_GUIDE_TO_MODELS[task_guide]
|
||||||
special_model_types = SPECIAL_TASK_GUIDE_TO_MODEL_TYPES.get(task_guide, set())
|
special_model_types = SPECIAL_TASK_GUIDE_TO_MODEL_TYPES.get(task_guide, set())
|
||||||
@@ -95,9 +129,17 @@ def get_model_list_for_task(task_guide):
|
|||||||
return ", ".join([f"[{name}](../model_doc/{code})" for code, name in model_names.items()]) + "\n"
|
return ", ".join([f"[{name}](../model_doc/{code})" for code, name in model_names.items()]) + "\n"
|
||||||
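The return statement above builds a comma-separated list of Markdown links. A minimal sketch of that formatting step, using a hand-written mapping in place of the auto-module lookup (the function name `format_model_links` and the sample entries are assumptions for illustration):

```python
from typing import Dict


def format_model_links(model_names: Dict[str, str]) -> str:
    """Hypothetical helper: turn {doc page code: display name} into a line of doc links."""
    # Same expression as the return statement in the diff above.
    return ", ".join([f"[{name}](../model_doc/{code})" for code, name in model_names.items()]) + "\n"


print(format_model_links({"bert": "BERT", "gpt2": "OpenAI GPT-2"}), end="")
```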
|
|
||||||
|
|
||||||
def check_model_list_for_task(task_guide, overwrite=False):
|
def check_model_list_for_task(task_guide: str, overwrite: bool = False):
|
||||||
"""For a given task guide, checks the model list in the generated tip for consistency with the state of the lib and overwrites if needed."""
|
"""
|
||||||
|
For a given task guide, checks the model list in the generated tip for consistency with the state of the lib and
|
||||||
|
updates it if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_guide (`str`):
|
||||||
|
The name of the task guide to check.
|
||||||
|
overwrite (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to overwrite the table when it's not up to date.
|
||||||
|
"""
|
||||||
current_list, start_index, end_index, lines = _find_text_in_file(
|
current_list, start_index, end_index, lines = _find_text_in_file(
|
||||||
filename=os.path.join(PATH_TO_TASK_GUIDES, task_guide),
|
filename=os.path.join(PATH_TO_TASK_GUIDES, task_guide),
|
||||||
start_prompt="<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->",
|
start_prompt="<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->",
|
||||||
|
|||||||
@@ -12,12 +12,35 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Utility that sorts the imports in the custom inits of Transformers. Transformers uses init files that delay the
|
||||||
|
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
|
||||||
|
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
|
||||||
|
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
|
||||||
|
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff`
|
||||||
|
properly sort the second half, which looks like traditional imports; the goal of this script is to sort the first half.
|
||||||
|
|
||||||
|
Use from the root of the repo with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/custom_init_isort.py
|
||||||
|
```
|
||||||
|
|
||||||
|
which will auto-sort the imports (used in `make style`).
|
||||||
|
|
||||||
|
For a check only (as used in `make quality`) run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/custom_init_isort.py --check_only
|
||||||
|
```
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
# Path is defined with the intent you should run this script from the root of the repo.
|
||||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||||
|
|
||||||
# Pattern that looks at the indentation in a line.
|
# Pattern that looks at the indentation in a line.
|
||||||
@@ -32,17 +55,30 @@ _re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
|
|||||||
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
|
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
|
||||||
|
|
||||||
|
|
||||||
def get_indent(line):
|
def get_indent(line: str) -> str:
|
||||||
"""Returns the indent in `line`."""
|
"""Returns the indent of the given line (as a string)."""
|
||||||
search = _re_indent.search(line)
|
search = _re_indent.search(line)
|
||||||
return "" if search is None else search.groups()[0]
|
return "" if search is None else search.groups()[0]
|
||||||
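A standalone sketch of `get_indent` follows. The definition of `_re_indent` is not visible in this hunk, so the pattern below is an assumption: leading whitespace captured up to the first non-space character.

```python
import re

# Assumption: the pattern is not shown in the diff; this one captures leading whitespace.
_re_indent = re.compile(r"^(\s*)\S")


def get_indent(line: str) -> str:
    """Returns the indent of the given line (as a string)."""
    search = _re_indent.search(line)
    # Empty or whitespace-only lines have no match and therefore no indent.
    return "" if search is None else search.groups()[0]


print(repr(get_indent("    x = 1")))
print(repr(get_indent("x = 1")))
print(repr(get_indent("")))
```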
|
|
||||||
|
|
||||||
def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
|
def split_code_in_indented_blocks(
|
||||||
|
code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
|
Split some code into its indented blocks, starting at a given level.
|
||||||
`start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
|
|
||||||
after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
|
Args:
|
||||||
|
code (`str`): The code to split.
|
||||||
|
indent_level (`str`): The indent level (as string) to use for identifying the blocks to split.
|
||||||
|
start_prompt (`str`, *optional*): If provided, only starts splitting at the line where this text is.
|
||||||
|
end_prompt (`str`, *optional*): If provided, stops splitting at a line where this text is.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
The text before `start_prompt` or after `end_prompt` (if provided) is not ignored, just not split. The input `code`
|
||||||
|
can thus be retrieved by joining the result.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of blocks.
|
||||||
"""
|
"""
|
||||||
# Let's split the code into lines and move to start_index.
|
# Let's split the code into lines and move to start_index.
|
||||||
index = 0
|
index = 0
|
||||||
@@ -54,12 +90,17 @@ def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_
|
|||||||
else:
|
else:
|
||||||
blocks = []
|
blocks = []
|
||||||
|
|
||||||
# We split into blocks until we get to the `end_prompt` (or the end of the block).
|
# This variable contains the block treated at a given time.
|
||||||
current_block = [lines[index]]
|
current_block = [lines[index]]
|
||||||
index += 1
|
index += 1
|
||||||
|
# We split into blocks until we get to the `end_prompt` (or the end of the file).
|
||||||
while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
|
while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
|
||||||
|
# We have a non-empty line with the proper indent -> start of a new block
|
||||||
if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
|
if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
|
||||||
|
# Store the current block in the result and reset. There are two cases: the line is part of the block (like
|
||||||
|
# a closing parenthesis) or not.
|
||||||
if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
|
if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
|
||||||
|
# Line is part of the current block
|
||||||
current_block.append(lines[index])
|
current_block.append(lines[index])
|
||||||
blocks.append("\n".join(current_block))
|
blocks.append("\n".join(current_block))
|
||||||
if index < len(lines) - 1:
|
if index < len(lines) - 1:
|
||||||
@@ -68,9 +109,11 @@ def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_
|
|||||||
else:
|
else:
|
||||||
current_block = []
|
current_block = []
|
||||||
else:
|
else:
|
||||||
|
# Line is not part of the current block
|
||||||
blocks.append("\n".join(current_block))
|
blocks.append("\n".join(current_block))
|
||||||
current_block = [lines[index]]
|
current_block = [lines[index]]
|
||||||
else:
|
else:
|
||||||
|
# Just add the line to the current block
|
||||||
current_block.append(lines[index])
|
current_block.append(lines[index])
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
@@ -85,8 +128,10 @@ def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_
|
|||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
def ignore_underscore(key):
|
def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]:
|
||||||
"Wraps a `key` (that maps an object to string) to lower case and remove underscores."
|
"""
|
||||||
|
Wraps a key function (as used in a sort) to lowercase and ignore underscores.
|
||||||
|
"""
|
||||||
|
|
||||||
def _inner(x):
|
def _inner(x):
|
||||||
return key(x).lower().replace("_", "")
|
return key(x).lower().replace("_", "")
|
||||||
@@ -94,8 +139,21 @@ def ignore_underscore(key):
|
|||||||
return _inner
|
return _inner
|
||||||
|
|
||||||
|
|
||||||
def sort_objects(objects, key=None):
|
def sort_objects(objects: List[Any], key: Optional[Callable[[Any], str]] = None) -> List[Any]:
|
||||||
"Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
|
"""
|
||||||
|
Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
|
||||||
|
last).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
objects (`List[Any]`):
|
||||||
|
The list of objects to sort.
|
||||||
|
key (`Callable[[Any], str]`, *optional*):
|
||||||
|
A function taking an object as input and returning a string, used to sort them by alphabetical order.
|
||||||
|
If not provided, will default to noop (so a `key` must be provided if the `objects` are not of type string).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[Any]`: The sorted list with the same elements as in the inputs
|
||||||
|
"""
|
||||||
|
|
||||||
# If no key is provided, we use a noop.
|
# If no key is provided, we use a noop.
|
||||||
def noop(x):
|
def noop(x):
|
||||||
@@ -110,18 +168,26 @@ def sort_objects(objects, key=None):
|
|||||||
# Functions begin with a lowercase, they go last.
|
# Functions begin with a lowercase, they go last.
|
||||||
functions = [obj for obj in objects if not key(obj)[0].isupper()]
|
functions = [obj for obj in objects if not key(obj)[0].isupper()]
|
||||||
|
|
||||||
key1 = ignore_underscore(key)
|
# Then we sort each group.
|
||||||
|
key1 = ignore_underscore_and_lowercase(key)
|
||||||
return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
|
return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
|
||||||
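The grouping described in the docstring can be sketched end to end. The `functions` line and the final `return` follow the diff; the `constants` and `classes` lines are elided there, so the versions below are assumptions derived from the docstring (all-uppercase names first, then names starting with a capital).

```python
from typing import Any, Callable, List, Optional


def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]:
    """Wraps a key function to lowercase its result and ignore underscores."""

    def _inner(x):
        return key(x).lower().replace("_", "")

    return _inner


def sort_objects(objects: List[Any], key: Optional[Callable[[Any], str]] = None) -> List[Any]:
    """Sketch: constants first, classes second, functions last, each group sorted."""
    if key is None:
        key = lambda x: x  # noop, so objects must be strings

    # Assumption: constants are fully uppercase names.
    constants = [obj for obj in objects if key(obj).isupper()]
    # Assumption: classes start with a capital but are not fully uppercase.
    classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
    # Functions begin with a lowercase, they go last (as in the diff).
    functions = [obj for obj in objects if not key(obj)[0].isupper()]

    key1 = ignore_underscore_and_lowercase(key)
    return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)


names = ["is_torch_available", "BertModel", "BERT_INPUTS_DOCSTRING", "AutoModel", "add_start_docstrings"]
print(sort_objects(names))
```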
|
|
||||||
|
|
||||||
def sort_objects_in_import(import_statement):
|
def sort_objects_in_import(import_statement: str) -> str:
|
||||||
"""
|
"""
|
||||||
Return the same `import_statement` but with objects properly sorted.
|
Sorts the imports in a single import statement.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
import_statement (`str`): The import statement in which to sort the imports.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The same as the input, but with objects properly sorted.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This inner function sorts imports between [ ].
|
# This inner function sorts imports between [ ].
|
||||||
def _replace(match):
|
def _replace(match):
|
||||||
imports = match.groups()[0]
|
imports = match.groups()[0]
|
||||||
|
# If there is one import only, nothing to do.
|
||||||
if "," not in imports:
|
if "," not in imports:
|
||||||
return f"[{imports}]"
|
return f"[{imports}]"
|
||||||
keys = [part.strip().replace('"', "") for part in imports.split(",")]
|
keys = [part.strip().replace('"', "") for part in imports.split(",")]
|
||||||
@@ -165,13 +231,18 @@ def sort_objects_in_import(import_statement):
|
|||||||
return import_statement
|
return import_statement
|
||||||
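The bracket handling in `_replace` can be tried on a single line. This is a sketch under assumptions: `sort_bracket_content` is a hypothetical wrapper, the part of `_replace` after the `keys = ...` line is elided in the diff and is filled in here with a plain guess, and a simple case-insensitive sort stands in for `sort_objects`.

```python
import re

_re_bracket_content = re.compile(r"\[([^\]]+)\]")


def sort_bracket_content(line: str) -> str:
    """Hypothetical wrapper: sort the quoted names inside every [...] on one line."""

    def _replace(match):
        imports = match.groups()[0]
        # If there is one import only, nothing to do.
        if "," not in imports:
            return f"[{imports}]"
        keys = [part.strip().replace('"', "") for part in imports.split(",")]
        # Assumption: drop the empty entry left by a trailing comma.
        keys = [k for k in keys if len(k) > 0]
        # Assumption: case-insensitive sort ignoring underscores, in place of sort_objects.
        keys = sorted(keys, key=lambda k: k.lower().replace("_", ""))
        return "[" + ", ".join(f'"{k}"' for k in keys) + "]"

    return _re_bracket_content.sub(_replace, line)


print(sort_bracket_content('_import_structure["models.bert"] = ["BertModel", "BertConfig"]'))
```

The dictionary key `["models.bert"]` also matches the bracket pattern, but it contains no comma and is returned unchanged.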
|
|
||||||
|
|
||||||
def sort_imports(file, check_only=True):
|
def sort_imports(file: str, check_only: bool = True):
|
||||||
"""
|
"""
|
||||||
Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite.
|
Sort the imports defined in the `_import_structure` of a given init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (`str`): The path to the init to check/fix.
|
||||||
|
check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.
|
||||||
"""
|
"""
|
||||||
with open(file, encoding="utf-8") as f:
|
with open(file, encoding="utf-8") as f:
|
||||||
code = f.read()
|
code = f.read()
|
||||||
|
|
||||||
|
# If the file is not a custom init, there is nothing to do.
|
||||||
if "_import_structure" not in code:
|
if "_import_structure" not in code:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -234,6 +305,12 @@ def sort_imports(file, check_only=True):
|
|||||||
|
|
||||||
|
|
||||||
def sort_imports_in_all_inits(check_only=True):
|
def sort_imports_in_all_inits(check_only=True):
|
||||||
|
"""
|
||||||
|
Sort the imports defined in the `_import_structure` of all inits in the repo.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.
|
||||||
|
"""
|
||||||
failures = []
|
failures = []
|
||||||
for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
|
for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
|
||||||
if "__init__.py" in files:
|
if "__init__.py" in files:
|
||||||
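Note on the inner helper shown in the first hunk above: only its opening lines are visible (a lone import is returned untouched, otherwise the names are split on commas and stripped of quotes). Below is a minimal sketch of how such a helper could finish. The function name `sort_bracket_contents`, the trailing-comma handling, and the case-insensitive ordering are assumptions made for illustration; the ordering rule the real script applies is not visible in this diff.

```python
def sort_bracket_contents(imports: str) -> str:
    """Sort the quoted names found between the brackets of a one-line list."""
    # If there is one import only, nothing to do.
    if "," not in imports:
        return f"[{imports}]"
    keys = [part.strip().replace('"', "") for part in imports.split(",")]
    # Assumption: a trailing comma leaves an empty entry, which we drop.
    keys = [key for key in keys if key]
    # Assumption: plain case-insensitive alphabetical order.
    ordered = sorted(keys, key=str.lower)
    return "[" + ", ".join(f'"{key}"' for key in ordered) + "]"


print(sort_bracket_contents('"b", "A", "c",'))
```

The early return matters because a single name needs no reordering, so the original text is preserved exactly.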
|
|||||||
@@ -12,7 +12,35 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Utility that prepares the repository for releases (or patches) by updating all versions in the relevant places. It
|
||||||
|
also performs some post-release cleanup, by updating the links in the main README to respective model doc pages (from
|
||||||
|
main to stable).
|
||||||
|
|
||||||
|
To prepare for a release, use from the root of the repo on the release branch with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python release.py
|
||||||
|
```
|
||||||
|
|
||||||
|
or use `make pre-release`.
|
||||||
|
|
||||||
|
To prepare for a patch release, use from the root of the repo on the release branch with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python release.py --patch
|
||||||
|
```
|
||||||
|
|
||||||
|
or use `make pre-patch`.
|
||||||
|
|
||||||
|
To do the post-release cleanup, use from the root of the repo on the main branch with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python release.py --post_release
|
||||||
|
```
|
||||||
|
|
||||||
|
or use `make post-release`.
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -20,13 +48,16 @@ import re
|
|||||||
import packaging.version
|
import packaging.version
|
||||||
|
|
||||||
|
|
||||||
|
# All paths are defined with the intent that this script should be run from the root of the repo.
|
||||||
PATH_TO_EXAMPLES = "examples/"
|
PATH_TO_EXAMPLES = "examples/"
|
||||||
|
# This maps a type of file to the pattern to look for when searching where the version is defined, as well as the
|
||||||
|
# template to follow when replacing it with the new version.
|
||||||
REPLACE_PATTERNS = {
|
REPLACE_PATTERNS = {
|
||||||
"examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'),
|
"examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'),
|
||||||
"init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
|
"init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
|
||||||
"setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
|
"setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
|
||||||
"doc": (re.compile(r'^(\s*)release\s*=\s*"[^"]+"$', re.MULTILINE), 'release = "VERSION"\n'),
|
|
||||||
}
|
}
|
||||||
|
# This maps a type of file to its path in Transformers
|
||||||
REPLACE_FILES = {
|
REPLACE_FILES = {
|
||||||
"init": "src/transformers/__init__.py",
|
"init": "src/transformers/__init__.py",
|
||||||
"setup": "setup.py",
|
"setup": "setup.py",
|
||||||
@@ -34,19 +65,31 @@ REPLACE_FILES = {
|
|||||||
README_FILE = "README.md"
|
README_FILE = "README.md"
|
||||||
|
|
||||||
|
|
||||||
def update_version_in_file(fname, version, pattern):
|
def update_version_in_file(fname: str, version: str, file_type: str):
|
||||||
"""Update the version in one file using a specific pattern."""
|
"""
|
||||||
|
Update the version of Transformers in one file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fname (`str`): The path to the file where we want to update the version.
|
||||||
|
version (`str`): The new version to set in the file.
|
||||||
|
file_type (`str`): The type of the file (should be a key in `REPLACE_PATTERNS`).
|
||||||
|
"""
|
||||||
with open(fname, "r", encoding="utf-8", newline="\n") as f:
|
with open(fname, "r", encoding="utf-8", newline="\n") as f:
|
||||||
code = f.read()
|
code = f.read()
|
||||||
re_pattern, replace = REPLACE_PATTERNS[pattern]
|
re_pattern, replace = REPLACE_PATTERNS[file_type]
|
||||||
replace = replace.replace("VERSION", version)
|
replace = replace.replace("VERSION", version)
|
||||||
code = re_pattern.sub(replace, code)
|
code = re_pattern.sub(replace, code)
|
||||||
with open(fname, "w", encoding="utf-8", newline="\n") as f:
|
with open(fname, "w", encoding="utf-8", newline="\n") as f:
|
||||||
f.write(code)
|
f.write(code)
|
||||||
|
|
||||||
|
|
||||||
def update_version_in_examples(version):
|
def update_version_in_examples(version: str):
|
||||||
"""Update the version in all examples files."""
|
"""
|
||||||
|
Update the version in all examples files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version (`str`): The new version to set in the examples.
|
||||||
|
"""
|
||||||
for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES):
|
for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES):
|
||||||
# Removing some of the folders with non-actively maintained examples from the walk
|
# Removing some of the folders with non-actively maintained examples from the walk
|
||||||
if "research_projects" in directories:
|
if "research_projects" in directories:
|
||||||
@@ -55,19 +98,28 @@ def update_version_in_examples(version):
|
|||||||
directories.remove("legacy")
|
directories.remove("legacy")
|
||||||
for fname in fnames:
|
for fname in fnames:
|
||||||
if fname.endswith(".py"):
|
if fname.endswith(".py"):
|
||||||
update_version_in_file(os.path.join(folder, fname), version, pattern="examples")
|
update_version_in_file(os.path.join(folder, fname), version, file_type="examples")
|
||||||
|
|
||||||
|
|
||||||
def global_version_update(version, patch=False):
|
def global_version_update(version: str, patch: bool = False):
|
||||||
"""Update the version in all needed files."""
|
"""
|
||||||
|
Update the version in all needed files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version (`str`): The new version to set everywhere.
|
||||||
|
patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
|
||||||
|
"""
|
||||||
for pattern, fname in REPLACE_FILES.items():
|
for pattern, fname in REPLACE_FILES.items():
|
||||||
update_version_in_file(fname, version, pattern)
|
update_version_in_file(fname, version, pattern)
|
||||||
if not patch:
|
if not patch:
|
||||||
|
# We don't update the version in the examples for patch releases.
|
||||||
update_version_in_examples(version)
|
update_version_in_examples(version)
|
||||||
|
|
||||||
|
|
||||||
def clean_main_ref_in_model_list():
|
def clean_main_ref_in_model_list():
|
||||||
"""Replace the links from main doc tp stable doc in the model list of the README."""
|
"""
|
||||||
|
Replace the links from main doc to stable doc in the model list of the README.
|
||||||
|
"""
|
||||||
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
|
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
|
||||||
_start_prompt = "🤗 Transformers currently provides the following architectures"
|
_start_prompt = "🤗 Transformers currently provides the following architectures"
|
||||||
_end_prompt = "1. Want to contribute a new model?"
|
_end_prompt = "1. Want to contribute a new model?"
|
||||||
@@ -94,16 +146,26 @@ def clean_main_ref_in_model_list():
|
|||||||
f.writelines(lines)
|
f.writelines(lines)
|
||||||
|
|
||||||
|
|
||||||
def get_version():
|
def get_version() -> packaging.version.Version:
|
||||||
"""Reads the current version in the __init__."""
|
"""
|
||||||
|
Reads the current version in the main __init__.
|
||||||
|
"""
|
||||||
with open(REPLACE_FILES["init"], "r") as f:
|
with open(REPLACE_FILES["init"], "r") as f:
|
||||||
code = f.read()
|
code = f.read()
|
||||||
default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0]
|
default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0]
|
||||||
return packaging.version.parse(default_version)
|
return packaging.version.parse(default_version)
|
||||||
|
|
||||||
|
|
||||||
def pre_release_work(patch=False):
|
def pre_release_work(patch: bool = False):
|
||||||
"""Do all the necessary pre-release steps."""
|
"""
|
||||||
|
Do all the necessary pre-release steps:
|
||||||
|
- figure out the next minor release version and ask for confirmation
|
||||||
|
- update the version everywhere
|
||||||
|
- clean-up the model list in the main README
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
|
||||||
|
"""
|
||||||
# First let's get the default version: base version if we are in dev, bump minor otherwise.
|
# First let's get the default version: base version if we are in dev, bump minor otherwise.
|
||||||
default_version = get_version()
|
default_version = get_version()
|
||||||
if patch and default_version.is_devrelease:
|
if patch and default_version.is_devrelease:
|
||||||
@@ -115,7 +177,7 @@ def pre_release_work(patch=False):
|
|||||||
else:
|
else:
|
||||||
default_version = f"{default_version.major}.{default_version.minor + 1}.0"
|
default_version = f"{default_version.major}.{default_version.minor + 1}.0"
|
||||||
|
|
||||||
# Now let's ask nicely if that's the right one.
|
# Now let's ask nicely if we have found the right version.
|
||||||
version = input(f"Which version are you releasing? [{default_version}]")
|
version = input(f"Which version are you releasing? [{default_version}]")
|
||||||
if len(version) == 0:
|
if len(version) == 0:
|
||||||
version = default_version
|
version = default_version
|
||||||
@@ -128,7 +190,12 @@ def pre_release_work(patch=False):
|
|||||||
|
|
||||||
|
|
||||||
def post_release_work():
|
def post_release_work():
|
||||||
"""Do all the necesarry post-release steps."""
|
"""
|
||||||
|
Do all the necessary post-release steps:
|
||||||
|
- figure out the next dev version and ask for confirmation
|
||||||
|
- update the version everywhere
|
||||||
|
- clean-up the model list in the main README
|
||||||
|
"""
|
||||||
# First let's get the current version
|
# First let's get the current version
|
||||||
current_version = get_version()
|
current_version = get_version()
|
||||||
dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
|
dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
|
||||||
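Note on the version handling in the release script above: the same compiled pattern is used both to read the current version (through its capture group) and to rewrite it (through the `VERSION` template). The sketch below works on an in-memory string rather than on files. The pattern and template are copied from the `"init"` entry of `REPLACE_PATTERNS`; the helper names `read_version`, `replace_version` and `next_dev_version` are hypothetical, and the string-based bump is a stand-in for the `packaging.version` attributes used by the script.

```python
import re

# Copied from the "init" entry of REPLACE_PATTERNS in the diff above.
INIT_PATTERN = re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE)
INIT_TEMPLATE = '__version__ = "VERSION"\n'


def read_version(code: str) -> str:
    """Return the version string captured by the pattern's first group."""
    return INIT_PATTERN.search(code).groups()[0]


def replace_version(code: str, version: str) -> str:
    """Return `code` with the version line rewritten to `version`."""
    return INIT_PATTERN.sub(INIT_TEMPLATE.replace("VERSION", version), code)


def next_dev_version(version: str) -> str:
    """Hypothetical helper: bump the minor number and append `.0.dev0`."""
    major, minor = version.split(".")[:2]
    return f"{major}.{int(minor) + 1}.0.dev0"


# Illustrative file content, not taken from the repository.
sample = 'import os\n\n__version__ = "4.32.0"\n\nfrom x import y\n'
updated = replace_version(sample, next_dev_version(read_version(sample)))
print(read_version(updated))
```

Because the pattern ends in `\s*$` and the template ends in a newline, the line break consumed by the match is written back by the replacement when the version line is followed by a blank line, as in the sample.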
|
|||||||
@@ -12,12 +12,30 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Utility that sorts the names in the auto mappings defined in the auto modules in alphabetical order.
|
||||||
|
|
||||||
|
Use from the root of the repo with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/sort_auto_mappings.py
|
||||||
|
```
|
||||||
|
|
||||||
|
to auto-fix all the auto mappings (used in `make style`).
|
||||||
|
|
||||||
|
To only check if the mappings are properly sorted (as used in `make quality`), do:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/sort_auto_mappings.py --check_only
|
||||||
|
```
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# Paths are set with the intent that you should run this script from the root of the repo.
|
||||||
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"
|
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"
|
||||||
|
|
||||||
|
|
||||||
@@ -28,7 +46,18 @@ _re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict
|
|||||||
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
|
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
|
||||||
|
|
||||||
|
|
||||||
def sort_auto_mapping(fname, overwrite: bool = False):
|
def sort_auto_mapping(fname: str, overwrite: bool = False) -> Optional[bool]:
|
||||||
|
"""
|
||||||
|
Sort all auto mappings in a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fname (`str`): The name of the file where we want to sort auto-mappings.
|
||||||
|
overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Optional[bool]`: Returns `None` if `overwrite=True`. Otherwise returns `True` if the file has an auto-mapping
|
||||||
|
improperly sorted, `False` if the file is okay.
|
||||||
|
"""
|
||||||
with open(fname, "r", encoding="utf-8") as f:
|
with open(fname, "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
@@ -37,8 +66,8 @@ def sort_auto_mapping(fname, overwrite: bool = False):
|
|||||||
line_idx = 0
|
line_idx = 0
|
||||||
while line_idx < len(lines):
|
while line_idx < len(lines):
|
||||||
if _re_intro_mapping.search(lines[line_idx]) is not None:
|
if _re_intro_mapping.search(lines[line_idx]) is not None:
|
||||||
indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
|
|
||||||
# Start of a new mapping!
|
# Start of a new mapping!
|
||||||
|
indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
|
||||||
while not lines[line_idx].startswith(" " * indent + "("):
|
while not lines[line_idx].startswith(" " * indent + "("):
|
||||||
new_lines.append(lines[line_idx])
|
new_lines.append(lines[line_idx])
|
||||||
line_idx += 1
|
line_idx += 1
|
||||||
@@ -65,11 +94,17 @@ def sort_auto_mapping(fname, overwrite: bool = False):
|
|||||||
if overwrite:
|
if overwrite:
|
||||||
with open(fname, "w", encoding="utf-8") as f:
|
with open(fname, "w", encoding="utf-8") as f:
|
||||||
f.write("\n".join(new_lines))
|
f.write("\n".join(new_lines))
|
||||||
elif "\n".join(new_lines) != content:
|
else:
|
||||||
return True
|
return "\n".join(new_lines) != content
|
||||||
|
|
||||||
|
|
||||||
def sort_all_auto_mappings(overwrite: bool = False):
|
def sort_all_auto_mappings(overwrite: bool = False):
|
||||||
|
"""
|
||||||
|
Sort all auto mappings in the library.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.
|
||||||
|
"""
|
||||||
fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py")]
|
fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py")]
|
||||||
diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]
|
diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]
|
||||||
|
|
||||||
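Note on the return-value change in `sort_auto_mapping` above: with the new `else` branch, check mode always returns a boolean (`True` when the file would change, `False` when it is already sorted), while overwrite mode returns `None`. The sketch below reproduces that contract on a simplified task, sorting the lines of a text file. The function `sort_lines` and its line-sorting rule are hypothetical and stand in for the mapping-aware logic that this diff does not show in full.

```python
import os
import tempfile
from typing import Optional


def sort_lines(fname: str, overwrite: bool = False) -> Optional[bool]:
    """Sort the lines of `fname`; in check mode, report whether it would change."""
    with open(fname, "r", encoding="utf-8") as f:
        content = f.read()

    lines = content.split("\n")
    # Keep the empty entry produced by a final newline in place; sort the rest.
    body, tail = (lines[:-1], lines[-1:]) if lines[-1] == "" else (lines, [])
    new_lines = sorted(body) + tail

    if overwrite:
        with open(fname, "w", encoding="utf-8") as f:
            f.write("\n".join(new_lines))
    else:
        return "\n".join(new_lines) != content


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "names.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("b\na\n")
    print(sort_lines(path))                  # check only
    print(sort_lines(path, overwrite=True))  # fix in place
    print(sort_lines(path))                  # check again
```

Returning an explicit `False` for a clean file, instead of falling through to `None`, lets a caller such as `sort_all_auto_mappings` distinguish "checked and fine" from "was rewritten".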
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Welcome to tests_fetcher V2.
|
Welcome to tests_fetcher V2.
|
||||||
|
|
||||||
This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and
|
This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and
|
||||||
when too many models are being impacted, only run the tests of a subset of core models. It works like this.
|
when too many models are being impacted, only run the tests of a subset of core models. It works like this.
|
||||||
|
|
||||||
|
|||||||
@@ -12,12 +12,28 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Utility that updates the metadata of the Transformers library in the repository `huggingface/transformers-metadata`.
|
||||||
|
|
||||||
|
Usage for an update (as used by the GitHub action `update_metadata`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/update_metadata.py --token <token> --commit_sha <commit_sha>
|
||||||
|
```
|
||||||
|
|
||||||
|
Usage to check all pipelines are properly defined in the constant `PIPELINE_TAGS_AND_AUTO_MODELS` of this script, so
|
||||||
|
that new pipelines are properly added as metadata (as used in `make repo-consistency`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/update_metadata.py --check-only
|
||||||
|
```
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import collections
|
import collections
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -102,14 +118,29 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
def camel_case_split(identifier: str) -> List[str]:
|
||||||
def camel_case_split(identifier):
|
"""
|
||||||
"Split a camelcased `identifier` into words."
|
Split a camel-cased name into words.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier (`str`): The camel-cased name to parse.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of words in the identifier (as separated by capital letters).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> camel_case_split("CamelCasedClass")
|
||||||
|
["Camel", "Cased", "Class"]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# Regex thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||||
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
|
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
|
||||||
return [m.group(0) for m in matches]
|
return [m.group(0) for m in matches]
|
||||||
|
|
||||||
|
|
||||||
def get_frameworks_table():
|
def get_frameworks_table() -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
Generates a dataframe containing the supported auto classes for each model type, using the content of the auto
|
Generates a dataframe containing the supported auto classes for each model type, using the content of the auto
|
||||||
modules.
|
modules.
|
||||||
@@ -155,7 +186,8 @@ def get_frameworks_table():
|
|||||||
data["tensorflow"] = [tf_models[t] for t in all_models]
|
data["tensorflow"] = [tf_models[t] for t in all_models]
|
||||||
data["flax"] = [flax_models[t] for t in all_models]
|
data["flax"] = [flax_models[t] for t in all_models]
|
||||||
|
|
||||||
# Now let's use the auto-mapping names to make sure
|
# Now let's find the right processing class for each model. In order we check if there is a Processor, then a
|
||||||
|
# Tokenizer, then a FeatureExtractor, then an ImageProcessor
|
||||||
processors = {}
|
processors = {}
|
||||||
for t in all_models:
|
for t in all_models:
|
||||||
if t in transformers_module.models.auto.processing_auto.PROCESSOR_MAPPING_NAMES:
|
if t in transformers_module.models.auto.processing_auto.PROCESSOR_MAPPING_NAMES:
|
||||||
@@ -164,6 +196,8 @@ def get_frameworks_table():
|
|||||||
processors[t] = "AutoTokenizer"
|
processors[t] = "AutoTokenizer"
|
||||||
elif t in transformers_module.models.auto.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES:
|
elif t in transformers_module.models.auto.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES:
|
||||||
processors[t] = "AutoFeatureExtractor"
|
processors[t] = "AutoFeatureExtractor"
|
||||||
|
elif t in transformers_module.models.auto.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES:
|
||||||
|
processors[t] = "AutoFeatureExtractor"
|
||||||
else:
|
else:
|
||||||
# Default to AutoTokenizer if a model has nothing, for backward compatibility.
|
# Default to AutoTokenizer if a model has nothing, for backward compatibility.
|
||||||
processors[t] = "AutoTokenizer"
|
processors[t] = "AutoTokenizer"
|
||||||
@@ -173,10 +207,17 @@ def get_frameworks_table():
|
|||||||
return pd.DataFrame(data)
|
return pd.DataFrame(data)
|
||||||
|
|
||||||
|
|
||||||
def update_pipeline_and_auto_class_table(table):
|
def update_pipeline_and_auto_class_table(table: Dict[str, Tuple[str, str]]) -> Dict[str, Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
Update the table of model class to (pipeline_tag, auto_class) without removing old keys if they don't exist
|
Update the table mapping models to pipelines and auto classes without removing old keys if they don't exist anymore.
|
||||||
anymore.
|
|
||||||
|
Args:
|
||||||
|
table (`Dict[str, Tuple[str, str]]`):
|
||||||
|
The existing table mapping model names to a tuple containing the pipeline tag and the auto-class name with
|
||||||
|
which they should be used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Dict[str, Tuple[str, str]]`: The updated table in the same format.
|
||||||
"""
|
"""
|
||||||
auto_modules = [
|
auto_modules = [
|
||||||
transformers_module.models.auto.modeling_auto,
|
transformers_module.models.auto.modeling_auto,
|
||||||
@@ -205,9 +246,13 @@ def update_pipeline_and_auto_class_table(table):
|
|||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
def update_metadata(token, commit_sha):
|
def update_metadata(token: str, commit_sha: str):
|
||||||
"""
|
"""
|
||||||
Update the metadata for the Transformers repo.
|
Update the metadata for the Transformers repo in `huggingface/transformers-metadata`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (`str`): A valid token giving write access to `huggingface/transformers-metadata`.
|
||||||
|
commit_sha (`str`): The commit SHA on Transformers corresponding to this update.
|
||||||
"""
|
"""
|
||||||
frameworks_table = get_frameworks_table()
|
frameworks_table = get_frameworks_table()
|
||||||
frameworks_dataset = Dataset.from_pandas(frameworks_table)
|
frameworks_dataset = Dataset.from_pandas(frameworks_table)
|
||||||
@@ -255,6 +300,9 @@ def update_metadata(token, commit_sha):
|
|||||||
|
|
||||||
|
|
||||||
def check_pipeline_tags():
|
def check_pipeline_tags():
|
||||||
|
"""
|
||||||
|
Check all pipeline tags are properly defined in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant of this script.
|
||||||
|
"""
|
||||||
in_table = {tag: cls for tag, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS}
|
in_table = {tag: cls for tag, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS}
|
||||||
pipeline_tasks = transformers_module.pipelines.SUPPORTED_TASKS
|
pipeline_tasks = transformers_module.pipelines.SUPPORTED_TASKS
|
||||||
missing = []
|
missing = []
|
||||||
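Note on `camel_case_split` above: the regex cuts at two kinds of boundaries, a lowercase letter followed by an uppercase one, and an uppercase letter followed by an uppercase-then-lowercase pair. The second rule keeps acronym prefixes together. The standalone copy below reuses the function body from the diff; the `TFBertModel` input is an extra illustration that does not appear in the docstring.

```python
import re
from typing import List


def camel_case_split(identifier: str) -> List[str]:
    """Split a camel-cased name into words."""
    # Same regex as in the diff: lazily consume characters up to the next word boundary.
    matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
    return [m.group(0) for m in matches]


print(camel_case_split("CamelCasedClass"))
print(camel_case_split("TFBertModel"))
```

The lazy `.+?` is what makes each match stop at the first boundary rather than swallowing the whole identifier.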
|
|||||||