Make the big table creation/check platform independent (#8856)

This commit is contained in:
Sylvain Gugger
2020-12-01 11:45:57 -05:00
committed by GitHub
parent d366228df1
commit c0df963ee1
6 changed files with 92 additions and 41 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import argparse
import collections
import glob
import importlib
import os
@@ -298,6 +299,22 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
# Add here suffixes that are used to identify models, separated by |
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
# Regexes that match TF/Flax/PT model names. They are built from ALLOWED_MODEL_SUFFIXES so that a suffix
# added to the constant above is automatically recognized here as well. The single capturing group holds
# the model prefix, i.e. the class name stripped of the framework marker and of the model suffix.
_re_tf_models = re.compile(r"TF(.*)(?:" + ALLOWED_MODEL_SUFFIXES + r")")
_re_flax_models = re.compile(r"Flax(.*)(?:" + ALLOWED_MODEL_SUFFIXES + r")")
# Will match any TF or Flax model too, so it needs to be tried in an else branch after the two previous regexes.
_re_pt_models = re.compile(r"(.*)(?:" + ALLOWED_MODEL_SUFFIXES + r")")
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
def camel_case_split(identifier):
    """
    Split a camel-cased `identifier` into its component words.

    A word ends either where a lowercase letter is followed by an uppercase one, or where an uppercase
    letter is followed by an uppercase letter that itself starts a lowercase run (so a run of capitals
    stays together as one word), or at the end of the string.

    Returns the list of words in order; an empty `identifier` gives an empty list.
    """
    word_pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    words = []
    for found in re.finditer(word_pattern, identifier):
        words.append(found.group(0))
    return words
def _center_text(text, width):
text_length = 2 if text == "" or text == "" else len(text)
left_indent = (width - text_length) // 2
@@ -319,44 +336,43 @@ def get_model_table_from_auto_modules():
model_name_to_config = {
name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
}
# All tokenizer tuples.
tokenizers = {
name: transformers.TOKENIZER_MAPPING[config]
for name, config in model_name_to_config.items()
if config in transformers.TOKENIZER_MAPPING
model_name_to_prefix = {
name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
}
# Model names that have a slow/fast tokenizer.
has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None]
has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None]
# Model names that have a PyTorch implementation.
has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING]
# Some of the GenerationModel don't have a base model.
has_pt_model.extend(
[
name
for name, config in model_name_to_config.items()
if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
]
)
# Special exception for RAG
has_pt_model.append("RAG")
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
slow_tokenizers = collections.defaultdict(bool)
fast_tokenizers = collections.defaultdict(bool)
pt_models = collections.defaultdict(bool)
tf_models = collections.defaultdict(bool)
flax_models = collections.defaultdict(bool)
# Model names that have a TensorFlow implementation.
has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING]
# Some of the GenerationModel don't have a base model.
has_tf_model.extend(
[
name
for name, config in model_name_to_config.items()
if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
]
)
# Let's look through all transformers objects (once).
for attr_name in dir(transformers):
lookup_dict = None
if attr_name.endswith("Tokenizer"):
lookup_dict = slow_tokenizers
attr_name = attr_name[:-9]
elif attr_name.endswith("TokenizerFast"):
lookup_dict = fast_tokenizers
attr_name = attr_name[:-13]
elif _re_tf_models.match(attr_name) is not None:
lookup_dict = tf_models
attr_name = _re_tf_models.match(attr_name).groups()[0]
elif _re_flax_models.match(attr_name) is not None:
lookup_dict = flax_models
attr_name = _re_flax_models.match(attr_name).groups()[0]
elif _re_pt_models.match(attr_name) is not None:
lookup_dict = pt_models
attr_name = _re_pt_models.match(attr_name).groups()[0]
# Model names that have a Flax implementation.
has_flax_model = [
name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING
]
if lookup_dict is not None:
while len(attr_name) > 0:
if attr_name in model_name_to_prefix.values():
lookup_dict[attr_name] = True
break
# Try again after removing the last word in the name
attr_name = "".join(camel_case_split(attr_name)[:-1])
# Let's build that table!
model_names = list(model_name_to_config.keys())
@@ -374,13 +390,14 @@ def get_model_table_from_auto_modules():
check = {True: "", False: ""}
for name in model_names:
prefix = model_name_to_prefix[name]
line = [
name,
check[name in has_slow_tokenizers],
check[name in has_fast_tokenizers],
check[name in has_pt_model],
check[name in has_tf_model],
check[name in has_flax_model],
check[slow_tokenizers[prefix]],
check[fast_tokenizers[prefix]],
check[pt_models[prefix]],
check[tf_models[prefix]],
check[flax_models[prefix]],
]
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"