Big model table (#8774)
* First draft * Styling * With all changes staged * Update docs/source/index.rst Co-authored-by: Julien Chaumond <chaumond@gmail.com> * Styling Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -2,6 +2,15 @@
|
|||||||
|
|
||||||
/* Colab dropdown */
|
/* Colab dropdown */
|
||||||
|
|
||||||
|
table.center-aligned-table td {
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
table.center-aligned-table th {
|
||||||
|
text-align: center;
|
||||||
|
vertical-align: middle;
|
||||||
|
}
|
||||||
|
|
||||||
.colab-dropdown {
|
.colab-dropdown {
|
||||||
position: relative;
|
position: relative;
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ Choose the right framework for every part of a model's lifetime:
|
|||||||
- Move a single model between TF2.0/PyTorch frameworks at will
|
- Move a single model between TF2.0/PyTorch frameworks at will
|
||||||
- Seamlessly pick the right framework for training, evaluation, production
|
- Seamlessly pick the right framework for training, evaluation, production
|
||||||
|
|
||||||
|
Experimental support for Flax with a few models right now, expected to grow in the coming months.
|
||||||
|
|
||||||
Contents
|
Contents
|
||||||
-----------------------------------------------------------------------------------------------------------------------
|
-----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
@@ -52,8 +54,8 @@ The documentation is organized in five parts:
|
|||||||
- **MODELS** for the classes and functions related to each model implemented in the library.
|
- **MODELS** for the classes and functions related to each model implemented in the library.
|
||||||
- **INTERNAL HELPERS** for the classes and functions we use internally.
|
- **INTERNAL HELPERS** for the classes and functions we use internally.
|
||||||
|
|
||||||
The library currently contains PyTorch and Tensorflow implementations, pre-trained model weights, usage scripts and
|
The library currently contains PyTorch, Tensorflow and Flax implementations, pretrained model weights, usage scripts
|
||||||
conversion utilities for the following models:
|
and conversion utilities for the following models:
|
||||||
|
|
||||||
..
|
..
|
||||||
This list is updated automatically from the README with `make fix-copies`. Do not update manually!
|
This list is updated automatically from the README with `make fix-copies`. Do not update manually!
|
||||||
@@ -166,6 +168,95 @@ conversion utilities for the following models:
|
|||||||
34. `Other community models <https://huggingface.co/models>`__, contributed by the `community
|
34. `Other community models <https://huggingface.co/models>`__, contributed by the `community
|
||||||
<https://huggingface.co/users>`__.
|
<https://huggingface.co/users>`__.
|
||||||
|
|
||||||
|
|
||||||
|
The table below represents the current support in the library for each of those models, whether they have a Python
|
||||||
|
tokenizer (called "slow"). A "fast" tokenizer backed by the 🤗 Tokenizers library, whether they have support in PyTorch,
|
||||||
|
TensorFlow and/or Flax.
|
||||||
|
|
||||||
|
..
|
||||||
|
This table is updated automatically from the auto modules with `make fix-copies`. Do not update manually!
|
||||||
|
|
||||||
|
.. rst-class:: center-aligned-table
|
||||||
|
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Model | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
|
||||||
|
+=============================+================+================+=================+====================+==============+
|
||||||
|
| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| BART | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Bert Generation | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| DeBERTa | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| FlauBERT | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| LayoutLM | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Marian | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| RAG | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| T5 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| mBART | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
| mT5 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
:caption: Get started
|
:caption: Get started
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
|||||||
from .models.auto import (
|
from .models.auto import (
|
||||||
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
|
MODEL_NAMES_MAPPING,
|
||||||
TOKENIZER_MAPPING,
|
TOKENIZER_MAPPING,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -880,6 +881,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
|
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||||
from .models.bert import FlaxBertModel
|
from .models.bert import FlaxBertModel
|
||||||
from .models.roberta import FlaxRobertaModel
|
from .models.roberta import FlaxRobertaModel
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||||
# module, but to preserve other warnings. So, don't check this module at all.
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
from ...file_utils import is_tf_available, is_torch_available
|
from ...file_utils import is_flax_available, is_tf_available, is_torch_available
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
|
||||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -57,3 +57,6 @@ if is_tf_available():
|
|||||||
TFAutoModelForTokenClassification,
|
TFAutoModelForTokenClassification,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
|||||||
for key, value, in pretrained_map.items()
|
for key, value, in pretrained_map.items()
|
||||||
)
|
)
|
||||||
|
|
||||||
MODEL_MAPPING = OrderedDict(
|
FLAX_MODEL_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
(RobertaConfig, FlaxRobertaModel),
|
(RobertaConfig, FlaxRobertaModel),
|
||||||
(BertConfig, FlaxBertModel),
|
(BertConfig, FlaxBertModel),
|
||||||
@@ -79,13 +79,13 @@ class FlaxAutoModel(object):
|
|||||||
model = FlaxAutoModel.from_config(config)
|
model = FlaxAutoModel.from_config(config)
|
||||||
# E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
# E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
for config_class, model_class in MODEL_MAPPING.items():
|
for config_class, model_class in FLAX_MODEL_MAPPING.items():
|
||||||
if isinstance(config, config_class):
|
if isinstance(config, config_class):
|
||||||
return model_class(config)
|
return model_class(config)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized configuration class {config.__class__} "
|
f"Unrecognized configuration class {config.__class__} "
|
||||||
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
||||||
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}."
|
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}."
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -173,11 +173,11 @@ class FlaxAutoModel(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
for config_class, model_class in MODEL_MAPPING.items():
|
for config_class, model_class in FLAX_MODEL_MAPPING.items():
|
||||||
if isinstance(config, config_class):
|
if isinstance(config, config_class):
|
||||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized configuration class {config.__class__} "
|
f"Unrecognized configuration class {config.__class__} "
|
||||||
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
||||||
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}"
|
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,18 @@
|
|||||||
from ..file_utils import requires_flax
|
from ..file_utils import requires_flax
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertModel:
|
class FlaxBertModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_flax(self)
|
requires_flax(self)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
|
import importlib
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -250,20 +251,21 @@ def convert_to_rst(model_list, max_per_line=None):
|
|||||||
return "\n".join(result)
|
return "\n".join(result)
|
||||||
|
|
||||||
|
|
||||||
def check_model_list_copy(overwrite=False, max_per_line=119):
|
def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||||
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
|
"""
|
||||||
_start_prompt = " This list is updated automatically from the README"
|
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
|
||||||
_end_prompt = ".. toctree::"
|
lines.
|
||||||
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "r", encoding="utf-8", newline="\n") as f:
|
"""
|
||||||
|
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
# Find the start of the list.
|
# Find the start prompt.
|
||||||
start_index = 0
|
start_index = 0
|
||||||
while not lines[start_index].startswith(_start_prompt):
|
while not lines[start_index].startswith(start_prompt):
|
||||||
start_index += 1
|
start_index += 1
|
||||||
start_index += 1
|
start_index += 1
|
||||||
|
|
||||||
end_index = start_index
|
end_index = start_index
|
||||||
while not lines[end_index].startswith(_end_prompt):
|
while not lines[end_index].startswith(end_prompt):
|
||||||
end_index += 1
|
end_index += 1
|
||||||
end_index -= 1
|
end_index -= 1
|
||||||
|
|
||||||
@@ -272,8 +274,16 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
|||||||
while len(lines[end_index]) <= 1:
|
while len(lines[end_index]) <= 1:
|
||||||
end_index -= 1
|
end_index -= 1
|
||||||
end_index += 1
|
end_index += 1
|
||||||
|
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
||||||
|
|
||||||
rst_list = "".join(lines[start_index:end_index])
|
|
||||||
|
def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||||
|
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
|
||||||
|
rst_list, start_index, end_index, lines = _find_text_in_file(
|
||||||
|
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
|
||||||
|
start_prompt=" This list is updated automatically from the README",
|
||||||
|
end_prompt="The table below represents the current support",
|
||||||
|
)
|
||||||
md_list = get_model_list()
|
md_list = get_model_list()
|
||||||
converted_list = convert_to_rst(md_list, max_per_line=max_per_line)
|
converted_list = convert_to_rst(md_list, max_per_line=max_per_line)
|
||||||
|
|
||||||
@@ -283,7 +293,116 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
|||||||
f.writelines(lines[:start_index] + [converted_list] + lines[end_index:])
|
f.writelines(lines[:start_index] + [converted_list] + lines[end_index:])
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The model list in the README changed and the list in `index.rst` has not been updated. Run `make fix-copies` to fix this."
|
"The model list in the README changed and the list in `index.rst` has not been updated. Run "
|
||||||
|
"`make fix-copies` to fix this."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _center_text(text, width):
|
||||||
|
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
||||||
|
left_indent = (width - text_length) // 2
|
||||||
|
right_indent = width - text_length - left_indent
|
||||||
|
return " " * left_indent + text + " " * right_indent
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_table_from_auto_modules():
|
||||||
|
"""Generates an up-to-date model table from the content of the auto modules."""
|
||||||
|
# This is to make sure the transformers module imported is the one in the repo.
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"transformers",
|
||||||
|
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
|
||||||
|
submodule_search_locations=[TRANSFORMERS_PATH],
|
||||||
|
)
|
||||||
|
transformers = spec.loader.load_module()
|
||||||
|
|
||||||
|
# Dictionary model names to config.
|
||||||
|
model_name_to_config = {
|
||||||
|
name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
|
||||||
|
}
|
||||||
|
# All tokenizer tuples.
|
||||||
|
tokenizers = {
|
||||||
|
name: transformers.TOKENIZER_MAPPING[config]
|
||||||
|
for name, config in model_name_to_config.items()
|
||||||
|
if config in transformers.TOKENIZER_MAPPING
|
||||||
|
}
|
||||||
|
# Model names that a slow/fast tokenizer.
|
||||||
|
has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None]
|
||||||
|
has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None]
|
||||||
|
|
||||||
|
# Model names that have a PyTorch implementation.
|
||||||
|
has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING]
|
||||||
|
# Some of the GenerationModel don't have a base model.
|
||||||
|
has_pt_model.extend(
|
||||||
|
[
|
||||||
|
name
|
||||||
|
for name, config in model_name_to_config.items()
|
||||||
|
if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Special exception for RAG
|
||||||
|
has_pt_model.append("RAG")
|
||||||
|
|
||||||
|
# Model names that have a TensorFlow implementation.
|
||||||
|
has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING]
|
||||||
|
# Some of the GenerationModel don't have a base model.
|
||||||
|
has_tf_model.extend(
|
||||||
|
[
|
||||||
|
name
|
||||||
|
for name, config in model_name_to_config.items()
|
||||||
|
if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model names that have a Flax implementation.
|
||||||
|
has_flax_model = [
|
||||||
|
name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING
|
||||||
|
]
|
||||||
|
|
||||||
|
# Let's build that table!
|
||||||
|
model_names = list(model_name_to_config.keys())
|
||||||
|
model_names.sort()
|
||||||
|
columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
|
||||||
|
# We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
|
||||||
|
widths = [len(c) + 2 for c in columns]
|
||||||
|
widths[0] = max([len(name) for name in model_names]) + 2
|
||||||
|
|
||||||
|
# Rst table per se
|
||||||
|
table = ".. rst-class:: center-aligned-table\n\n"
|
||||||
|
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
|
||||||
|
table += "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n"
|
||||||
|
table += "+" + "+".join(["=" * w for w in widths]) + "+\n"
|
||||||
|
|
||||||
|
check = {True: "✅", False: "❌"}
|
||||||
|
for name in model_names:
|
||||||
|
line = [
|
||||||
|
name,
|
||||||
|
check[name in has_slow_tokenizers],
|
||||||
|
check[name in has_fast_tokenizers],
|
||||||
|
check[name in has_pt_model],
|
||||||
|
check[name in has_tf_model],
|
||||||
|
check[name in has_flax_model],
|
||||||
|
]
|
||||||
|
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
|
||||||
|
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_table(overwrite=False):
|
||||||
|
""" Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`. """
|
||||||
|
current_table, start_index, end_index, lines = _find_text_in_file(
|
||||||
|
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
|
||||||
|
start_prompt=" This table is updated automatically from the auto module",
|
||||||
|
end_prompt=".. toctree::",
|
||||||
|
)
|
||||||
|
new_table = get_model_table_from_auto_modules()
|
||||||
|
|
||||||
|
if current_table != new_table:
|
||||||
|
if overwrite:
|
||||||
|
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "w", encoding="utf-8", newline="\n") as f:
|
||||||
|
f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -293,3 +412,4 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
check_copies(args.fix_and_overwrite)
|
check_copies(args.fix_and_overwrite)
|
||||||
|
check_model_table(args.fix_and_overwrite)
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ def get_model_modules():
|
|||||||
"modeling_outputs",
|
"modeling_outputs",
|
||||||
"modeling_retribert",
|
"modeling_retribert",
|
||||||
"modeling_utils",
|
"modeling_utils",
|
||||||
|
"modeling_flax_auto",
|
||||||
"modeling_flax_utils",
|
"modeling_flax_utils",
|
||||||
"modeling_transfo_xl_utilities",
|
"modeling_transfo_xl_utilities",
|
||||||
"modeling_tf_auto",
|
"modeling_tf_auto",
|
||||||
|
|||||||
Reference in New Issue
Block a user