saving and reloading tokenizer configurations
This commit is contained in:
@@ -20,6 +20,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import six
|
import six
|
||||||
|
import copy
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
@@ -28,6 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
|
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
|
||||||
ADDED_TOKENS_FILE = 'added_tokens.json'
|
ADDED_TOKENS_FILE = 'added_tokens.json'
|
||||||
|
TOKENIZER_CONFIG_FILE = 'tokenizer_config.json'
|
||||||
|
|
||||||
class PreTrainedTokenizer(object):
|
class PreTrainedTokenizer(object):
|
||||||
""" Base class for all tokenizers.
|
""" Base class for all tokenizers.
|
||||||
@@ -168,9 +170,15 @@ class PreTrainedTokenizer(object):
|
|||||||
self._additional_special_tokens = []
|
self._additional_special_tokens = []
|
||||||
|
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
|
|
||||||
|
# Added tokens
|
||||||
self.added_tokens_encoder = {}
|
self.added_tokens_encoder = {}
|
||||||
self.added_tokens_decoder = {}
|
self.added_tokens_decoder = {}
|
||||||
|
|
||||||
|
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
|
||||||
|
self.init_inputs = ()
|
||||||
|
self.init_kwargs = {}
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
|
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
|
||||||
if key == 'additional_special_tokens':
|
if key == 'additional_special_tokens':
|
||||||
@@ -230,7 +238,7 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
force_download = kwargs.pop('force_download', False)
|
force_download = kwargs.pop('force_download', False)
|
||||||
proxies = kwargs.pop('proxies', None)
|
proxies = kwargs.pop('proxies', None)
|
||||||
@@ -266,15 +274,17 @@ class PreTrainedTokenizer(object):
|
|||||||
vocab_files[file_id] = full_file_name
|
vocab_files[file_id] = full_file_name
|
||||||
|
|
||||||
# Look for the additional tokens files
|
# Look for the additional tokens files
|
||||||
all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
|
additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
|
||||||
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
|
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE,
|
||||||
|
'tokenizer_config_file': TOKENIZER_CONFIG_FILE,
|
||||||
|
}
|
||||||
|
|
||||||
# If a path to a file was provided, get the parent directory
|
# If a path to a file was provided, get the parent directory
|
||||||
saved_directory = pretrained_model_name_or_path
|
saved_directory = pretrained_model_name_or_path
|
||||||
if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
|
if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
|
||||||
saved_directory = os.path.dirname(saved_directory)
|
saved_directory = os.path.dirname(saved_directory)
|
||||||
|
|
||||||
for file_id, file_name in all_vocab_files_names.items():
|
for file_id, file_name in additional_files_names.items():
|
||||||
full_file_name = os.path.join(saved_directory, file_name)
|
full_file_name = os.path.join(saved_directory, file_name)
|
||||||
if not os.path.exists(full_file_name):
|
if not os.path.exists(full_file_name):
|
||||||
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
||||||
@@ -317,8 +327,18 @@ class PreTrainedTokenizer(object):
|
|||||||
logger.info("loading file {} from cache at {}".format(
|
logger.info("loading file {} from cache at {}".format(
|
||||||
file_path, resolved_vocab_files[file_id]))
|
file_path, resolved_vocab_files[file_id]))
|
||||||
|
|
||||||
# Prepare initialization kwargs
|
# Prepare tokenizer initialization kwargs
|
||||||
|
# Did we save some inputs and kwargs to reload?
|
||||||
|
tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
|
||||||
|
if tokenizer_config_file is not None:
|
||||||
|
init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
|
||||||
|
saved_init_inputs = init_kwargs.pop('init_inputs', [])
|
||||||
|
if not init_inputs:
|
||||||
|
init_inputs = saved_init_inputs
|
||||||
|
else:
|
||||||
init_kwargs = init_configuration
|
init_kwargs = init_configuration
|
||||||
|
|
||||||
|
# Update with newly provided kwargs
|
||||||
init_kwargs.update(kwargs)
|
init_kwargs.update(kwargs)
|
||||||
|
|
||||||
# Set max length if needed
|
# Set max length if needed
|
||||||
@@ -342,7 +362,11 @@ class PreTrainedTokenizer(object):
|
|||||||
init_kwargs[key] = value
|
init_kwargs[key] = value
|
||||||
|
|
||||||
# Instantiate tokenizer.
|
# Instantiate tokenizer.
|
||||||
tokenizer = cls(*inputs, **init_kwargs)
|
tokenizer = cls(*init_inputs, **init_kwargs)
|
||||||
|
|
||||||
|
# Save inputs and kwargs for saving and re-loading with ``save_pretrained``
|
||||||
|
tokenizer.init_inputs = init_inputs
|
||||||
|
tokenizer.init_kwargs = init_kwargs
|
||||||
|
|
||||||
# Add supplementary tokens.
|
# Add supplementary tokens.
|
||||||
if added_tokens_file is not None:
|
if added_tokens_file is not None:
|
||||||
@@ -355,8 +379,13 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
|
|
||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory):
|
||||||
""" Save the tokenizer vocabulary files (with added tokens) and the
|
""" Save the tokenizer vocabulary files together with:
|
||||||
special-tokens-to-class-attributes-mapping to a directory.
|
- added tokens,
|
||||||
|
- special-tokens-to-class-attributes-mapping,
|
||||||
|
- tokenizer instantiation positional and keyword inputs (e.g. do_lower_case for Bert).
|
||||||
|
|
||||||
|
This won't save modifications (other than added tokens and special token mapping) you may have
|
||||||
|
applied to the tokenizer after the instantiation (e.g. modifying tokenizer.do_lower_case after creation).
|
||||||
|
|
||||||
This method makes sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
|
This method makes sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
|
||||||
"""
|
"""
|
||||||
@@ -366,6 +395,13 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
|
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
|
||||||
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
|
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
|
||||||
|
tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
|
||||||
|
|
||||||
|
tokenizer_config = copy.deepcopy(self.init_kwargs)
|
||||||
|
tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
|
||||||
|
|
||||||
|
with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
|
||||||
|
|
||||||
with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
|
with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
|
||||||
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
|
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
|
||||||
|
|||||||
Reference in New Issue
Block a user