Replaced torch.load for loading the pretrained vocab of TransformerXL tokenizer to pickle.load (#6935)

* Replaced torch.load for loading the pretrained vocab of TransformerXL to pickle.load

* Replaced torch.save with pickle.dump when saving the vocabulary

* updating transformer-xl

* uploaded on S3 - compatibility

* fix tests

* style

* Address review comments

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Piero Molino
2020-10-08 01:16:10 -07:00
committed by GitHub
parent aba4e22944
commit 4d04120c6d
2 changed files with 59 additions and 14 deletions

View File

@@ -203,6 +203,19 @@ def is_faiss_available():
return _faiss_available return _faiss_available
def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available:
raise ImportError(
"You need to install pytorch to use this method or class, "
"or activate it with environment variables USE_TORCH=1 and USE_TF=0."
)
else:
return fn(*args, **kwargs)
return wrapper
def is_sklearn_available(): def is_sklearn_available():
return _has_sklearn return _has_sklearn

View File

@@ -36,7 +36,7 @@ from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalize
from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
from tokenizers.processors import BertProcessing from tokenizers.processors import BertProcessing
from .file_utils import cached_path, is_torch_available from .file_utils import cached_path, is_torch_available, torch_only_method
from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_fast import PreTrainedTokenizerFast from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging from .utils import logging
@@ -48,12 +48,16 @@ if is_torch_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"} VOCAB_FILES_NAMES = {
"pretrained_vocab_file": "vocab.pkl",
"pretrained_vocab_file_torch": "vocab.bin",
"vocab_file": "vocab.txt",
}
VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"} VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
"pretrained_vocab_file": { "pretrained_vocab_file": {
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin", "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.pkl",
} }
} }
@@ -139,8 +143,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
File containing the vocabulary (from the original implementation). File containing the vocabulary (from the original implementation).
pretrained_vocab_file (:obj:`str`, `optional`): pretrained_vocab_file (:obj:`str`, `optional`):
File containing the vocabulary as saved with the :obj:`save_pretrained()` method. File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
never_split (xxx, `optional`): never_split (:obj:`List[str]`, `optional`):
Fill me with intesting stuff. List of tokens that should never be split. If no list is specified, will simply use the existing
special tokens.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`): unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead. token instead.
@@ -165,7 +170,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
lower_case=False, lower_case=False,
delimiter=None, delimiter=None,
vocab_file=None, vocab_file=None,
pretrained_vocab_file=None, pretrained_vocab_file: str = None,
never_split=None, never_split=None,
unk_token="<unk>", unk_token="<unk>",
eos_token="<eos>", eos_token="<eos>",
@@ -197,23 +202,40 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.moses_tokenizer = sm.MosesTokenizer(language) self.moses_tokenizer = sm.MosesTokenizer(language)
self.moses_detokenizer = sm.MosesDetokenizer(language) self.moses_detokenizer = sm.MosesDetokenizer(language)
# This try... catch... is not beautiful but honestly this tokenizer was not made to be used
# in a library like ours, at all.
try: try:
vocab_dict = None
if pretrained_vocab_file is not None: if pretrained_vocab_file is not None:
# Hack because, honestly this tokenizer was not made to be used # Priority on pickle files (support PyTorch and TF)
# in a library like ours, at all. with open(pretrained_vocab_file, "rb") as f:
vocab_dict = torch.load(pretrained_vocab_file) vocab_dict = pickle.load(f)
# Loading a torch-saved transfo-xl vocab dict with pickle results in an integer
# Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed.
# We therefore load it with torch, if it's available.
if type(vocab_dict) == int:
if not is_torch_available():
raise ImportError(
"Not trying to load dict with PyTorch as you need to install pytorch to load "
"from a PyTorch pretrained vocabulary, "
"or activate it with environment variables USE_TORCH=1 and USE_TF=0."
)
vocab_dict = torch.load(pretrained_vocab_file)
if vocab_dict is not None:
for key, value in vocab_dict.items(): for key, value in vocab_dict.items():
if key not in self.__dict__: if key not in self.__dict__:
self.__dict__[key] = value self.__dict__[key] = value
elif vocab_file is not None:
if vocab_file is not None:
self.build_vocab() self.build_vocab()
except Exception:
except Exception as e:
raise ValueError( raise ValueError(
"Unable to parse file {}. Unknown format. " "Unable to parse file {}. Unknown format. "
"If you tried to load a model saved through TransfoXLTokenizerFast," "If you tried to load a model saved through TransfoXLTokenizerFast,"
"please note they are not compatible.".format(pretrained_vocab_file) "please note they are not compatible.".format(pretrained_vocab_file)
) ) from e
if vocab_file is not None: if vocab_file is not None:
self.build_vocab() self.build_vocab()
@@ -286,7 +308,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"]) vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
else: else:
vocab_file = vocab_path vocab_file = vocab_path
torch.save(self.__dict__, vocab_file) with open(vocab_file, "wb") as f:
pickle.dump(self.__dict__, f)
return (vocab_file,) return (vocab_file,)
def build_vocab(self): def build_vocab(self):
@@ -309,6 +332,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter))) logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))
@torch_only_method
def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False): def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
if verbose: if verbose:
logger.info("encoding file {} ...".format(path)) logger.info("encoding file {} ...".format(path))
@@ -326,6 +350,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
return encoded return encoded
@torch_only_method
def encode_sents(self, sents, ordered=False, verbose=False): def encode_sents(self, sents, ordered=False, verbose=False):
if verbose: if verbose:
logger.info("encoding {} sents ...".format(len(sents))) logger.info("encoding {} sents ...".format(len(sents)))
@@ -436,6 +461,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
out_string = self.moses_detokenizer.detokenize(tokens) out_string = self.moses_detokenizer.detokenize(tokens)
return detokenize_numbers(out_string).strip() return detokenize_numbers(out_string).strip()
@torch_only_method
def convert_to_tensor(self, symbols): def convert_to_tensor(self, symbols):
return torch.LongTensor(self.convert_tokens_to_ids(symbols)) return torch.LongTensor(self.convert_tokens_to_ids(symbols))
@@ -706,6 +732,7 @@ class LMShuffledIterator(object):
for idx in epoch_indices: for idx in epoch_indices:
yield self.data[idx] yield self.data[idx]
@torch_only_method
def stream_iterator(self, sent_stream): def stream_iterator(self, sent_stream):
# streams for each data in the batch # streams for each data in the batch
streams = [None] * self.bsz streams = [None] * self.bsz
@@ -795,6 +822,7 @@ class LMMultiFileIterator(LMShuffledIterator):
class TransfoXLCorpus(object): class TransfoXLCorpus(object):
@classmethod @classmethod
@torch_only_method
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
""" """
Instantiate a pre-processed corpus. Instantiate a pre-processed corpus.
@@ -892,10 +920,14 @@ class TransfoXLCorpus(object):
data_iter = LMOrderedIterator(data, *args, **kwargs) data_iter = LMOrderedIterator(data, *args, **kwargs)
elif self.dataset == "lm1b": elif self.dataset == "lm1b":
data_iter = LMShuffledIterator(data, *args, **kwargs) data_iter = LMShuffledIterator(data, *args, **kwargs)
else:
data_iter = None
raise ValueError(f"Split not recognized: {split}")
return data_iter return data_iter
@torch_only_method
def get_lm_corpus(datadir, dataset): def get_lm_corpus(datadir, dataset):
fn = os.path.join(datadir, "cache.pt") fn = os.path.join(datadir, "cache.pt")
fn_pickle = os.path.join(datadir, "cache.pkl") fn_pickle = os.path.join(datadir, "cache.pkl")