From 88111de07c40797aaca619be693616c3c4cda4bd Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 16:55:48 +0200 Subject: [PATCH] saving and reloading tokenizer configurations --- pytorch_transformers/tokenization_utils.py | 54 ++++++++++++++++++---- 1 file changed, 45 insertions(+), 9 deletions(-) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 19b37da8c8..51e59fe46c 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -20,6 +20,7 @@ import logging import os import json import six +import copy from io import open from .file_utils import cached_path @@ -28,6 +29,7 @@ logger = logging.getLogger(__name__) SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' ADDED_TOKENS_FILE = 'added_tokens.json' +TOKENIZER_CONFIG_FILE = 'tokenizer_config.json' class PreTrainedTokenizer(object): """ Base class for all tokenizers. @@ -168,9 +170,15 @@ class PreTrainedTokenizer(object): self._additional_special_tokens = [] self.max_len = max_len if max_len is not None else int(1e12) + + # Added tokens self.added_tokens_encoder = {} self.added_tokens_decoder = {} + # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) + self.init_inputs = () + self.init_kwargs = {} + for key, value in kwargs.items(): if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == 'additional_special_tokens': @@ -230,7 +238,7 @@ class PreTrainedTokenizer(object): @classmethod - def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): cache_dir = kwargs.pop('cache_dir', None) force_download = kwargs.pop('force_download', False) proxies = kwargs.pop('proxies', None) @@ -266,15 +274,17 @@ class PreTrainedTokenizer(object): vocab_files[file_id] = full_file_name # Look for the additional tokens files - all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, - 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE} + additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, + 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE, + 'tokenizer_config_file': TOKENIZER_CONFIG_FILE, + } # If a path to a file was provided, get the parent directory saved_directory = pretrained_model_name_or_path if os.path.exists(saved_directory) and not os.path.isdir(saved_directory): saved_directory = os.path.dirname(saved_directory) - for file_id, file_name in all_vocab_files_names.items(): + for file_id, file_name in additional_files_names.items(): full_file_name = os.path.join(saved_directory, file_name) if not os.path.exists(full_file_name): logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) @@ -317,8 +327,18 @@ class PreTrainedTokenizer(object): logger.info("loading file {} from cache at {}".format( file_path, resolved_vocab_files[file_id])) - # Prepare initialization kwargs - init_kwargs = init_configuration + # Prepare tokenizer initialization kwargs + # Did we saved some inputs and kwargs to reload ? + tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None) + if tokenizer_config_file is not None: + init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8")) + saved_init_inputs = init_kwargs.pop('init_inputs', []) + if not init_inputs: + init_inputs = saved_init_inputs + else: + init_kwargs = init_configuration + + # Update with newly provided kwargs init_kwargs.update(kwargs) # Set max length if needed @@ -342,7 +362,11 @@ class PreTrainedTokenizer(object): init_kwargs[key] = value # Instantiate tokenizer. - tokenizer = cls(*inputs, **init_kwargs) + tokenizer = cls(*init_inputs, **init_kwargs) + + # Save inputs and kwargs for saving and re-loading with ``save_pretrained`` + tokenizer.init_inputs = init_inputs + tokenizer.init_kwargs = init_kwargs # Add supplementary tokens. if added_tokens_file is not None: @@ -355,8 +379,13 @@ class PreTrainedTokenizer(object): def save_pretrained(self, save_directory): - """ Save the tokenizer vocabulary files (with added tokens) and the - special-tokens-to-class-attributes-mapping to a directory. + """ Save the tokenizer vocabulary files together with: + - added tokens, + - special-tokens-to-class-attributes-mapping, + - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert). + + This won't save modifications other than (added tokens and special token mapping) you may have + applied to the tokenizer after the instantion (e.g. modifying tokenizer.do_lower_case after creation). This method make sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method. """ @@ -366,6 +395,13 @@ class PreTrainedTokenizer(object): special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) + tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE) + + tokenizer_config = copy.deepcopy(self.init_kwargs) + tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs) + + with open(tokenizer_config_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(tokenizer_config, ensure_ascii=False)) with open(special_tokens_map_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))