From 82462c5cba0ec07a3eeb1e9455d229ceaf43b5f2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 15:30:41 +0200 Subject: [PATCH] Added option to setup pretrained tokenizer arguments --- pytorch_transformers/tokenization_bert.py | 36 +++--- pytorch_transformers/tokenization_utils.py | 23 ++-- pytorch_transformers/tokenization_xlm.py | 135 +++++++++++++++++++-- 3 files changed, 159 insertions(+), 35 deletions(-) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 04f35aa466..d1ace940f0 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'bert-base-cased-finetuned-mrpc': 512, } +PRETRAINED_INIT_CONFIGURATION = { + 'bert-base-uncased': {'do_lower_case': True}, + 'bert-large-uncased': {'do_lower_case': True}, + 'bert-base-cased': {'do_lower_case': False}, + 'bert-large-cased': {'do_lower_case': False}, + 'bert-base-multilingual-uncased': {'do_lower_case': True}, + 'bert-base-multilingual-cased': {'do_lower_case': False}, + 'bert-base-chinese': {'do_lower_case': False}, + 'bert-base-german-cased': {'do_lower_case': False}, + 'bert-large-uncased-whole-word-masking': {'do_lower_case': True}, + 'bert-large-cased-whole-word-masking': {'do_lower_case': False}, + 'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True}, + 'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False}, + 'bert-base-cased-finetuned-mrpc': {'do_lower_case': False}, +} + + def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() @@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, 
vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, @@ -199,24 +217,6 @@ class BertTokenizer(PreTrainedTokenizer): index += 1 return (vocab_file,) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): - """ Instantiate a BertTokenizer from pre-trained vocabulary files. - """ - if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: - if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is a cased model but you have not set " - "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " - "you may want to check this behavior.") - kwargs['do_lower_case'] = False - elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is an uncased model but you have set " - "`do_lower_case` to False. We are setting `do_lower_case=True` for you " - "but you may want to check this behavior.") - kwargs['do_lower_case'] = True - - return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - class BasicTokenizer(object): """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 1d05441593..19b37da8c8 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -40,6 +40,7 @@ class PreTrainedTokenizer(object): - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). 
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file. - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size. + - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method. Parameters: @@ -61,6 +62,7 @@ class PreTrainedTokenizer(object): """ vocab_files_names = {} pretrained_vocab_files_map = {} + pretrained_init_configuration = {} max_model_input_sizes = {} SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token", @@ -235,10 +237,13 @@ class PreTrainedTokenizer(object): s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} + init_configuration = {} if pretrained_model_name_or_path in s3_models: # Get the vocabulary from AWS S3 bucket for file_id, map_list in cls.pretrained_vocab_files_map.items(): vocab_files[file_id] = map_list[pretrained_model_name_or_path] + if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration: + init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path] else: # Get the vocabulary from local files logger.info( @@ -312,28 +317,32 @@ class PreTrainedTokenizer(object): logger.info("loading file {} from cache at {}".format( file_path, resolved_vocab_files[file_id])) + # Prepare initialization 
kwargs + init_kwargs = init_configuration + init_kwargs.update(kwargs) + # Set max length if needed if pretrained_model_name_or_path in cls.max_model_input_sizes: # if we're using a pretrained model, ensure the tokenizer # wont index sequences longer than the number of positional embeddings max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] if max_len is not None and isinstance(max_len, (int, float)): - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len) - # Merge resolved_vocab_files arguments in kwargs. + # Merge resolved_vocab_files arguments in init_kwargs. added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) for args_name, file_path in resolved_vocab_files.items(): - if args_name not in kwargs: - kwargs[args_name] = file_path + if args_name not in init_kwargs: + init_kwargs[args_name] = file_path if special_tokens_map_file is not None: special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8")) for key, value in special_tokens_map.items(): - if key not in kwargs: - kwargs[key] = value + if key not in init_kwargs: + init_kwargs[key] = value # Instantiate tokenizer. - tokenizer = cls(*inputs, **kwargs) + tokenizer = cls(*inputs, **init_kwargs) # Add supplementary tokens. 
if added_tokens_file is not None: diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 71bf119387..c40d4cd16e 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -47,7 +47,9 @@ PRETRAINED_VOCAB_FILES_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json", - }, + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", + } 'merges_file': { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt", @@ -58,6 +60,8 @@ PRETRAINED_VOCAB_FILES_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt", }, } @@ -70,6 +74,101 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'xlm-mlm-xnli15-1024': 512, 'xlm-clm-enfr-1024': 512, 'xlm-clm-ende-1024': 512, + 'xlm-mlm-17-1280': 512, + 'xlm-mlm-100-1280': 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + 'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True}, + 'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "de", + "1": "en"}, + "lang2id": { "de": 0, + "en": 1 }}, + 
'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "fr"}, + "lang2id": { "en": 0, + "fr": 1 }}, + 'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "ro"}, + "lang2id": { "en": 0, + "ro": 1 }}, + 'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "ar", + "1": "bg", + "2": "de", + "3": "el", + "4": "en", + "5": "es", + "6": "fr", + "7": "hi", + "8": "ru", + "9": "sw", + "10": "th", + "11": "tr", + "12": "ur", + "13": "vi", + "14": "zh"}, + "lang2id": { "ar": 0, + "bg": 1, + "de": 2, + "el": 3, + "en": 4, + "es": 5, + "fr": 6, + "hi": 7, + "ru": 8, + "sw": 9, + "th": 10, + "tr": 11, + "ur": 12, + "vi": 13, + "zh": 14 }}, + 'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "ar", + "1": "bg", + "2": "de", + "3": "el", + "4": "en", + "5": "es", + "6": "fr", + "7": "hi", + "8": "ru", + "9": "sw", + "10": "th", + "11": "tr", + "12": "ur", + "13": "vi", + "14": "zh"}, + "lang2id": { "ar": 0, + "bg": 1, + "de": 2, + "el": 3, + "en": 4, + "es": 5, + "fr": 6, + "hi": 7, + "ru": 8, + "sw": 9, + "th": 10, + "tr": 11, + "ur": 12, + "vi": 13, + "zh": 14 }}, + 'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "fr"}, + "lang2id": { "en": 0, + "fr": 1 }}, + 'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "de", + "1": "en"}, + "lang2id": { "de": 0, + "en": 1 }}, + 'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False}, + 'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False}, } def get_pairs(word): @@ -183,17 +282,26 @@ class XLMTokenizer(PreTrainedTokenizer): - (optionally) lower case & normalize all inputs text - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \ - (ex: "__classify__") to a vocabulary. 
+ (ex: "__classify__") to a vocabulary + + - `lang2id` attribute maps the languages supported by the model with their ids if provided (automatically set for pretrained vocabularies) + + - `id2lang` attribute does the reverse mapping if provided (automatically set for pretrained vocabularies) + + - `do_lowercase_and_remove_accent` controls lower casing and accent removal (automatically set for pretrained vocabularies) """ vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, vocab_file, merges_file, unk_token="", bos_token="", sep_token="", pad_token="", cls_token="", mask_token="", additional_special_tokens=["", "", "", "", "", "", - "", "", "", ""], **kwargs): + "", "", "", ""], + lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True, + **kwargs): super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, @@ -206,7 +314,12 @@ class XLMTokenizer(PreTrainedTokenizer): self.cache_moses_tokenizer = dict() self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja']) # True for current supported model (v1.2.0), False for XLM-17 & 100 - self.do_lowercase_and_remove_accent = True + self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent + self.lang2id = lang2id + self.id2lang = id2lang + if lang2id is not None and id2lang is not None: + assert len(lang2id) == len(id2lang) + self.ja_word_tokenizer = None self.zh_word_tokenizer = None @@ -244,14 +357,14 @@ class XLMTokenizer(PreTrainedTokenizer): try: import Mykytea self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~')) - except: + except (AttributeError, ImportError) as e: logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper 
(https://github.com/chezou/Mykytea-python) with the following steps") logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea") logger.error("2. autoreconf -i") logger.error("3. ./configure --prefix=$HOME/local") logger.error("4. make && make install") logger.error("5. pip install kytea") - import sys; sys.exit() + raise e return list(self.ja_word_tokenizer.getWS(text)) @property @@ -336,6 +449,8 @@ class XLMTokenizer(PreTrainedTokenizer): Returns: List of tokens. """ + if lang and self.lang2id and lang not in self.lang2id: + logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.") if bypass_tokenizer: text = text.split() elif lang not in self.lang_with_custom_tokenizer: @@ -349,19 +464,19 @@ class XLMTokenizer(PreTrainedTokenizer): try: if 'pythainlp' not in sys.modules: from pythainlp.tokenize import word_tokenize as th_word_tokenize - except: + except (AttributeError, ImportError) as e: logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps") logger.error("1. pip install pythainlp") - import sys; sys.exit() + raise e text = th_word_tokenize(text) elif lang == 'zh': try: if 'jieba' not in sys.modules: import jieba - except: + except (AttributeError, ImportError) as e: logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps") logger.error("1. pip install jieba") - import sys; sys.exit() + raise e text = ' '.join(jieba.cut(text)) text = self.moses_pipeline(text, lang=lang) text = text.split()