Make the big table creation/check platform independent (#8856)
This commit is contained in:
@@ -214,6 +214,7 @@ if is_sentencepiece_available():
|
|||||||
from .models.camembert import CamembertTokenizer
|
from .models.camembert import CamembertTokenizer
|
||||||
from .models.marian import MarianTokenizer
|
from .models.marian import MarianTokenizer
|
||||||
from .models.mbart import MBartTokenizer
|
from .models.mbart import MBartTokenizer
|
||||||
|
from .models.mt5 import MT5Tokenizer
|
||||||
from .models.pegasus import PegasusTokenizer
|
from .models.pegasus import PegasusTokenizer
|
||||||
from .models.reformer import ReformerTokenizer
|
from .models.reformer import ReformerTokenizer
|
||||||
from .models.t5 import T5Tokenizer
|
from .models.t5 import T5Tokenizer
|
||||||
@@ -240,6 +241,7 @@ if is_tokenizers_available():
|
|||||||
from .models.lxmert import LxmertTokenizerFast
|
from .models.lxmert import LxmertTokenizerFast
|
||||||
from .models.mbart import MBartTokenizerFast
|
from .models.mbart import MBartTokenizerFast
|
||||||
from .models.mobilebert import MobileBertTokenizerFast
|
from .models.mobilebert import MobileBertTokenizerFast
|
||||||
|
from .models.mt5 import MT5TokenizerFast
|
||||||
from .models.openai import OpenAIGPTTokenizerFast
|
from .models.openai import OpenAIGPTTokenizerFast
|
||||||
from .models.pegasus import PegasusTokenizerFast
|
from .models.pegasus import PegasusTokenizerFast
|
||||||
from .models.reformer import ReformerTokenizerFast
|
from .models.reformer import ReformerTokenizerFast
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ if is_sentencepiece_available():
|
|||||||
from ..camembert.tokenization_camembert import CamembertTokenizer
|
from ..camembert.tokenization_camembert import CamembertTokenizer
|
||||||
from ..marian.tokenization_marian import MarianTokenizer
|
from ..marian.tokenization_marian import MarianTokenizer
|
||||||
from ..mbart.tokenization_mbart import MBartTokenizer
|
from ..mbart.tokenization_mbart import MBartTokenizer
|
||||||
|
from ..mt5 import MT5Tokenizer
|
||||||
from ..pegasus.tokenization_pegasus import PegasusTokenizer
|
from ..pegasus.tokenization_pegasus import PegasusTokenizer
|
||||||
from ..reformer.tokenization_reformer import ReformerTokenizer
|
from ..reformer.tokenization_reformer import ReformerTokenizer
|
||||||
from ..t5.tokenization_t5 import T5Tokenizer
|
from ..t5.tokenization_t5 import T5Tokenizer
|
||||||
@@ -111,6 +112,7 @@ else:
|
|||||||
CamembertTokenizer = None
|
CamembertTokenizer = None
|
||||||
MarianTokenizer = None
|
MarianTokenizer = None
|
||||||
MBartTokenizer = None
|
MBartTokenizer = None
|
||||||
|
MT5Tokenizer = None
|
||||||
PegasusTokenizer = None
|
PegasusTokenizer = None
|
||||||
ReformerTokenizer = None
|
ReformerTokenizer = None
|
||||||
T5Tokenizer = None
|
T5Tokenizer = None
|
||||||
@@ -135,6 +137,7 @@ if is_tokenizers_available():
|
|||||||
from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
|
from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
|
||||||
from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
|
from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
|
||||||
from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast
|
from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast
|
||||||
|
from ..mt5 import MT5TokenizerFast
|
||||||
from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast
|
from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast
|
||||||
from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast
|
from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast
|
||||||
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
|
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
|
||||||
@@ -161,6 +164,7 @@ else:
|
|||||||
LxmertTokenizerFast = None
|
LxmertTokenizerFast = None
|
||||||
MBartTokenizerFast = None
|
MBartTokenizerFast = None
|
||||||
MobileBertTokenizerFast = None
|
MobileBertTokenizerFast = None
|
||||||
|
MT5TokenizerFast = None
|
||||||
OpenAIGPTTokenizerFast = None
|
OpenAIGPTTokenizerFast = None
|
||||||
PegasusTokenizerFast = None
|
PegasusTokenizerFast = None
|
||||||
ReformerTokenizerFast = None
|
ReformerTokenizerFast = None
|
||||||
@@ -178,7 +182,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||||||
[
|
[
|
||||||
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
||||||
(T5Config, (T5Tokenizer, T5TokenizerFast)),
|
(T5Config, (T5Tokenizer, T5TokenizerFast)),
|
||||||
(MT5Config, (T5Tokenizer, T5TokenizerFast)),
|
(MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
|
||||||
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
|
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
|
||||||
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
||||||
(AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
|
(AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
|
||||||
|
|||||||
@@ -2,10 +2,20 @@
|
|||||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||||
# module, but to preserve other warnings. So, don't check this module at all.
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
from ...file_utils import is_tf_available, is_torch_available
|
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available
|
||||||
from .configuration_mt5 import MT5Config
|
from .configuration_mt5 import MT5Config
|
||||||
|
|
||||||
|
|
||||||
|
if is_sentencepiece_available():
|
||||||
|
from ..t5.tokenization_t5 import T5Tokenizer
|
||||||
|
|
||||||
|
MT5Tokenizer = T5Tokenizer
|
||||||
|
|
||||||
|
if is_tokenizers_available():
|
||||||
|
from ..t5.tokenization_t5_fast import T5TokenizerFast
|
||||||
|
|
||||||
|
MT5TokenizerFast = T5TokenizerFast
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
|
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,15 @@ class MBartTokenizer:
|
|||||||
requires_sentencepiece(self)
|
requires_sentencepiece(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MT5Tokenizer:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_sentencepiece(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_sentencepiece(self)
|
||||||
|
|
||||||
|
|
||||||
class PegasusTokenizer:
|
class PegasusTokenizer:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_sentencepiece(self)
|
requires_sentencepiece(self)
|
||||||
|
|||||||
@@ -164,6 +164,15 @@ class MobileBertTokenizerFast:
|
|||||||
requires_tokenizers(self)
|
requires_tokenizers(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MT5TokenizerFast:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tokenizers(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tokenizers(self)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTTokenizerFast:
|
class OpenAIGPTTokenizerFast:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_tokenizers(self)
|
requires_tokenizers(self)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import collections
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
@@ -298,6 +299,22 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Add here suffixes that are used to identify models, seperated by |
|
||||||
|
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
|
||||||
|
# Regexes that match TF/Flax/PT model names.
|
||||||
|
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||||
|
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||||
|
# Will match any TF or Flax model too so need to be in an else branch afterthe two previous regexes.
|
||||||
|
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||||
|
|
||||||
|
|
||||||
|
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||||
|
def camel_case_split(identifier):
|
||||||
|
"Split a camelcased `identifier` into words."
|
||||||
|
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
|
||||||
|
return [m.group(0) for m in matches]
|
||||||
|
|
||||||
|
|
||||||
def _center_text(text, width):
|
def _center_text(text, width):
|
||||||
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
||||||
left_indent = (width - text_length) // 2
|
left_indent = (width - text_length) // 2
|
||||||
@@ -319,44 +336,43 @@ def get_model_table_from_auto_modules():
|
|||||||
model_name_to_config = {
|
model_name_to_config = {
|
||||||
name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
|
name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
|
||||||
}
|
}
|
||||||
# All tokenizer tuples.
|
model_name_to_prefix = {
|
||||||
tokenizers = {
|
name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
|
||||||
name: transformers.TOKENIZER_MAPPING[config]
|
|
||||||
for name, config in model_name_to_config.items()
|
|
||||||
if config in transformers.TOKENIZER_MAPPING
|
|
||||||
}
|
}
|
||||||
# Model names that a slow/fast tokenizer.
|
|
||||||
has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None]
|
|
||||||
has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None]
|
|
||||||
|
|
||||||
# Model names that have a PyTorch implementation.
|
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
|
||||||
has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING]
|
slow_tokenizers = collections.defaultdict(bool)
|
||||||
# Some of the GenerationModel don't have a base model.
|
fast_tokenizers = collections.defaultdict(bool)
|
||||||
has_pt_model.extend(
|
pt_models = collections.defaultdict(bool)
|
||||||
[
|
tf_models = collections.defaultdict(bool)
|
||||||
name
|
flax_models = collections.defaultdict(bool)
|
||||||
for name, config in model_name_to_config.items()
|
|
||||||
if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Special exception for RAG
|
|
||||||
has_pt_model.append("RAG")
|
|
||||||
|
|
||||||
# Model names that have a TensorFlow implementation.
|
# Let's lookup through all transformers object (once).
|
||||||
has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING]
|
for attr_name in dir(transformers):
|
||||||
# Some of the GenerationModel don't have a base model.
|
lookup_dict = None
|
||||||
has_tf_model.extend(
|
if attr_name.endswith("Tokenizer"):
|
||||||
[
|
lookup_dict = slow_tokenizers
|
||||||
name
|
attr_name = attr_name[:-9]
|
||||||
for name, config in model_name_to_config.items()
|
elif attr_name.endswith("TokenizerFast"):
|
||||||
if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
lookup_dict = fast_tokenizers
|
||||||
]
|
attr_name = attr_name[:-13]
|
||||||
)
|
elif _re_tf_models.match(attr_name) is not None:
|
||||||
|
lookup_dict = tf_models
|
||||||
|
attr_name = _re_tf_models.match(attr_name).groups()[0]
|
||||||
|
elif _re_flax_models.match(attr_name) is not None:
|
||||||
|
lookup_dict = flax_models
|
||||||
|
attr_name = _re_flax_models.match(attr_name).groups()[0]
|
||||||
|
elif _re_pt_models.match(attr_name) is not None:
|
||||||
|
lookup_dict = pt_models
|
||||||
|
attr_name = _re_pt_models.match(attr_name).groups()[0]
|
||||||
|
|
||||||
# Model names that have a Flax implementation.
|
if lookup_dict is not None:
|
||||||
has_flax_model = [
|
while len(attr_name) > 0:
|
||||||
name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING
|
if attr_name in model_name_to_prefix.values():
|
||||||
]
|
lookup_dict[attr_name] = True
|
||||||
|
break
|
||||||
|
# Try again after removing the last word in the name
|
||||||
|
attr_name = "".join(camel_case_split(attr_name)[:-1])
|
||||||
|
|
||||||
# Let's build that table!
|
# Let's build that table!
|
||||||
model_names = list(model_name_to_config.keys())
|
model_names = list(model_name_to_config.keys())
|
||||||
@@ -374,13 +390,14 @@ def get_model_table_from_auto_modules():
|
|||||||
|
|
||||||
check = {True: "✅", False: "❌"}
|
check = {True: "✅", False: "❌"}
|
||||||
for name in model_names:
|
for name in model_names:
|
||||||
|
prefix = model_name_to_prefix[name]
|
||||||
line = [
|
line = [
|
||||||
name,
|
name,
|
||||||
check[name in has_slow_tokenizers],
|
check[slow_tokenizers[prefix]],
|
||||||
check[name in has_fast_tokenizers],
|
check[fast_tokenizers[prefix]],
|
||||||
check[name in has_pt_model],
|
check[pt_models[prefix]],
|
||||||
check[name in has_tf_model],
|
check[tf_models[prefix]],
|
||||||
check[name in has_flax_model],
|
check[flax_models[prefix]],
|
||||||
]
|
]
|
||||||
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
|
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
|
||||||
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
|
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
|
||||||
|
|||||||
Reference in New Issue
Block a user