From c0df963ee134a0872f45622a568072424d3ac96b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 1 Dec 2020 11:45:57 -0500 Subject: [PATCH] Make the big table creation/check platform independent (#8856) --- src/transformers/__init__.py | 2 + .../models/auto/tokenization_auto.py | 6 +- src/transformers/models/mt5/__init__.py | 12 ++- .../utils/dummy_sentencepiece_objects.py | 9 ++ .../utils/dummy_tokenizers_objects.py | 9 ++ utils/check_copies.py | 95 +++++++++++-------- 6 files changed, 92 insertions(+), 41 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2ae5ec3825..bacfbb64e1 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -214,6 +214,7 @@ if is_sentencepiece_available(): from .models.camembert import CamembertTokenizer from .models.marian import MarianTokenizer from .models.mbart import MBartTokenizer + from .models.mt5 import MT5Tokenizer from .models.pegasus import PegasusTokenizer from .models.reformer import ReformerTokenizer from .models.t5 import T5Tokenizer @@ -240,6 +241,7 @@ if is_tokenizers_available(): from .models.lxmert import LxmertTokenizerFast from .models.mbart import MBartTokenizerFast from .models.mobilebert import MobileBertTokenizerFast + from .models.mt5 import MT5TokenizerFast from .models.openai import OpenAIGPTTokenizerFast from .models.pegasus import PegasusTokenizerFast from .models.reformer import ReformerTokenizerFast diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 70774c3cb1..db70c75327 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -98,6 +98,7 @@ if is_sentencepiece_available(): from ..camembert.tokenization_camembert import CamembertTokenizer from ..marian.tokenization_marian import MarianTokenizer from ..mbart.tokenization_mbart import MBartTokenizer + from 
..mt5 import MT5Tokenizer from ..pegasus.tokenization_pegasus import PegasusTokenizer from ..reformer.tokenization_reformer import ReformerTokenizer from ..t5.tokenization_t5 import T5Tokenizer @@ -111,6 +112,7 @@ else: CamembertTokenizer = None MarianTokenizer = None MBartTokenizer = None + MT5Tokenizer = None PegasusTokenizer = None ReformerTokenizer = None T5Tokenizer = None @@ -135,6 +137,7 @@ if is_tokenizers_available(): from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast from ..mbart.tokenization_mbart_fast import MBartTokenizerFast from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast + from ..mt5 import MT5TokenizerFast from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast @@ -161,6 +164,7 @@ else: LxmertTokenizerFast = None MBartTokenizerFast = None MobileBertTokenizerFast = None + MT5TokenizerFast = None OpenAIGPTTokenizerFast = None PegasusTokenizerFast = None ReformerTokenizerFast = None @@ -178,7 +182,7 @@ TOKENIZER_MAPPING = OrderedDict( [ (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)), (T5Config, (T5Tokenizer, T5TokenizerFast)), - (MT5Config, (T5Tokenizer, T5TokenizerFast)), + (MT5Config, (MT5Tokenizer, MT5TokenizerFast)), (MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)), (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)), (AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)), diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py index 837a32966d..9cb90521c9 100644 --- a/src/transformers/models/mt5/__init__.py +++ b/src/transformers/models/mt5/__init__.py @@ -2,10 +2,20 @@ # There's no way to ignore "F401 '...' imported but unused" warnings in this # module, but to preserve other warnings. So, don't check this module at all. 
-from ...file_utils import is_tf_available, is_torch_available +from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available from .configuration_mt5 import MT5Config +if is_sentencepiece_available(): + from ..t5.tokenization_t5 import T5Tokenizer + + MT5Tokenizer = T5Tokenizer + +if is_tokenizers_available(): + from ..t5.tokenization_t5_fast import T5TokenizerFast + + MT5TokenizerFast = T5TokenizerFast + if is_torch_available(): from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index b78c274de6..bbaddd2cc9 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -56,6 +56,15 @@ class MBartTokenizer: requires_sentencepiece(self) +class MT5Tokenizer: + def __init__(self, *args, **kwargs): + requires_sentencepiece(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_sentencepiece(self) + + class PegasusTokenizer: def __init__(self, *args, **kwargs): requires_sentencepiece(self) diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 5592f28e41..5c105c9342 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -164,6 +164,15 @@ class MobileBertTokenizerFast: requires_tokenizers(self) +class MT5TokenizerFast: + def __init__(self, *args, **kwargs): + requires_tokenizers(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tokenizers(self) + + class OpenAIGPTTokenizerFast: def __init__(self, *args, **kwargs): requires_tokenizers(self) diff --git a/utils/check_copies.py b/utils/check_copies.py index 7cf89b987b..b52b658ec0 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -14,6 +14,7 @@ # 
limitations under the License. import argparse +import collections import glob import importlib import os @@ -298,6 +299,22 @@ def check_model_list_copy(overwrite=False, max_per_line=119): ) +# Add here suffixes that are used to identify models, separated by | +ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration" +# Regexes that match TF/Flax/PT model names. +_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") +_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") +# Will match any TF or Flax model too, so it needs to be in an else branch after the two previous regexes. +_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") + + +# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python +def camel_case_split(identifier): + "Split a camelcased `identifier` into words." + matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier) + return [m.group(0) for m in matches] + + def _center_text(text, width): text_length = 2 if text == "✅" or text == "❌" else len(text) left_indent = (width - text_length) // 2 @@ -319,44 +336,43 @@ def get_model_table_from_auto_modules(): model_name_to_config = { name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items() } - # All tokenizer tuples. - tokenizers = { - name: transformers.TOKENIZER_MAPPING[config] - for name, config in model_name_to_config.items() - if config in transformers.TOKENIZER_MAPPING + model_name_to_prefix = { + name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items() } - # Model names that a slow/fast tokenizer. - has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None] - has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None] - # Model names that have a PyTorch implementation. 
- has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING] - # Some of the GenerationModel don't have a base model. - has_pt_model.extend( - [ - name - for name, config in model_name_to_config.items() - if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - ] - ) - # Special exception for RAG - has_pt_model.append("RAG") + # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax. + slow_tokenizers = collections.defaultdict(bool) + fast_tokenizers = collections.defaultdict(bool) + pt_models = collections.defaultdict(bool) + tf_models = collections.defaultdict(bool) + flax_models = collections.defaultdict(bool) - # Model names that have a TensorFlow implementation. - has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING] - # Some of the GenerationModel don't have a base model. - has_tf_model.extend( - [ - name - for name, config in model_name_to_config.items() - if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - ] - ) + # Let's look through all transformers objects (once). + for attr_name in dir(transformers): + lookup_dict = None + if attr_name.endswith("Tokenizer"): + lookup_dict = slow_tokenizers + attr_name = attr_name[:-9] + elif attr_name.endswith("TokenizerFast"): + lookup_dict = fast_tokenizers + attr_name = attr_name[:-13] + elif _re_tf_models.match(attr_name) is not None: + lookup_dict = tf_models + attr_name = _re_tf_models.match(attr_name).groups()[0] + elif _re_flax_models.match(attr_name) is not None: + lookup_dict = flax_models + attr_name = _re_flax_models.match(attr_name).groups()[0] + elif _re_pt_models.match(attr_name) is not None: + lookup_dict = pt_models + attr_name = _re_pt_models.match(attr_name).groups()[0] - # Model names that have a Flax implementation. 
- has_flax_model = [ - name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING - ] + if lookup_dict is not None: + while len(attr_name) > 0: + if attr_name in model_name_to_prefix.values(): + lookup_dict[attr_name] = True + break + # Try again after removing the last word in the name + attr_name = "".join(camel_case_split(attr_name)[:-1]) # Let's build that table! model_names = list(model_name_to_config.keys()) @@ -374,13 +390,14 @@ def get_model_table_from_auto_modules(): check = {True: "✅", False: "❌"} for name in model_names: + prefix = model_name_to_prefix[name] line = [ name, - check[name in has_slow_tokenizers], - check[name in has_fast_tokenizers], - check[name in has_pt_model], - check[name in has_tf_model], - check[name in has_flax_model], + check[slow_tokenizers[prefix]], + check[fast_tokenizers[prefix]], + check[pt_models[prefix]], + check[tf_models[prefix]], + check[flax_models[prefix]], ] table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n" table += "+" + "+".join(["-" * w for w in widths]) + "+\n"