From 4d04120c6d76ed0ddf6525dd60c1211f4afffb2f Mon Sep 17 00:00:00 2001 From: Piero Molino Date: Thu, 8 Oct 2020 01:16:10 -0700 Subject: [PATCH] Replaced torch.load for loading the pretrained vocab of TransformerXL tokenizer to pickle.load (#6935) * Replaced torch.load for loading the pretrained vocab of TransformerXL to pickle.load * Replaced torch.save with pickle.dump when saving the vocabulary * updating transformer-xl * uploaded on S3 - compatibility * fix tests * style * Address review comments Co-authored-by: Thomas Wolf Co-authored-by: Lysandre --- src/transformers/file_utils.py | 13 +++++ src/transformers/tokenization_transfo_xl.py | 60 ++++++++++++++++----- 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 6314ea600a..d79072b361 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -203,6 +203,19 @@ def is_faiss_available(): return _faiss_available +def torch_only_method(fn): + def wrapper(*args, **kwargs): + if not _torch_available: + raise ImportError( + "You need to install pytorch to use this method or class, " + "or activate it with environment variables USE_TORCH=1 and USE_TF=0." + ) + else: + return fn(*args, **kwargs) + + return wrapper + + def is_sklearn_available(): return _has_sklearn diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py index 08c454d718..b6f34d43da 100644 --- a/src/transformers/tokenization_transfo_xl.py +++ b/src/transformers/tokenization_transfo_xl.py @@ -36,7 +36,7 @@ from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalize from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit from tokenizers.processors import BertProcessing -from .file_utils import cached_path, is_torch_available +from .file_utils import cached_path, is_torch_available, torch_only_method from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils_fast import PreTrainedTokenizerFast from .utils import logging @@ -48,12 +48,16 @@ if is_torch_available(): logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"} +VOCAB_FILES_NAMES = { + "pretrained_vocab_file": "vocab.pkl", + "pretrained_vocab_file_torch": "vocab.bin", + "vocab_file": "vocab.txt", +} VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"} PRETRAINED_VOCAB_FILES_MAP = { "pretrained_vocab_file": { - "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin", + "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.pkl", } } @@ -139,8 +143,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer): File containing the vocabulary (from the original implementation). pretrained_vocab_file (:obj:`str`, `optional`): File containing the vocabulary as saved with the :obj:`save_pretrained()` method. - never_split (xxx, `optional`): - Fill me with intesting stuff. + never_split (:obj:`List[str]`, `optional`): + List of tokens that should never be split. If no list is specified, will simply use the existing + special tokens. unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead. @@ -165,7 +170,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): lower_case=False, delimiter=None, vocab_file=None, - pretrained_vocab_file=None, + pretrained_vocab_file: str = None, never_split=None, unk_token="", eos_token="", @@ -197,23 +202,40 @@ class TransfoXLTokenizer(PreTrainedTokenizer): self.moses_tokenizer = sm.MosesTokenizer(language) self.moses_detokenizer = sm.MosesDetokenizer(language) + # This try... catch... is not beautiful but honestly this tokenizer was not made to be used + # in a library like ours, at all. try: + vocab_dict = None if pretrained_vocab_file is not None: - # Hack because, honestly this tokenizer was not made to be used - # in a library like ours, at all. - vocab_dict = torch.load(pretrained_vocab_file) + # Priority on pickle files (support PyTorch and TF) + with open(pretrained_vocab_file, "rb") as f: + vocab_dict = pickle.load(f) + + # Loading a torch-saved transfo-xl vocab dict with pickle results in an integer + # Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed. + # We therefore load it with torch, if it's available. + if type(vocab_dict) == int: + if not is_torch_available(): + raise ImportError( + "Not trying to load dict with PyTorch as you need to install pytorch to load " + "from a PyTorch pretrained vocabulary, " + "or activate it with environment variables USE_TORCH=1 and USE_TF=0." + ) + vocab_dict = torch.load(pretrained_vocab_file) + + if vocab_dict is not None: for key, value in vocab_dict.items(): if key not in self.__dict__: self.__dict__[key] = value - - if vocab_file is not None: + elif vocab_file is not None: self.build_vocab() - except Exception: + + except Exception as e: raise ValueError( "Unable to parse file {}. Unknown format. " "If you tried to load a model saved through TransfoXLTokenizerFast," "please note they are not compatible.".format(pretrained_vocab_file) - ) + ) from e if vocab_file is not None: self.build_vocab() @@ -286,7 +308,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer): vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"]) else: vocab_file = vocab_path - torch.save(self.__dict__, vocab_file) + with open(vocab_file, "wb") as f: + pickle.dump(self.__dict__, f) return (vocab_file,) def build_vocab(self): @@ -309,6 +332,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter))) + @torch_only_method def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False): if verbose: logger.info("encoding file {} ...".format(path)) @@ -326,6 +350,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): return encoded + @torch_only_method def encode_sents(self, sents, ordered=False, verbose=False): if verbose: logger.info("encoding {} sents ...".format(len(sents))) @@ -436,6 +461,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): out_string = self.moses_detokenizer.detokenize(tokens) return detokenize_numbers(out_string).strip() + @torch_only_method def convert_to_tensor(self, symbols): return torch.LongTensor(self.convert_tokens_to_ids(symbols)) @@ -706,6 +732,7 @@ class LMShuffledIterator(object): for idx in epoch_indices: yield self.data[idx] + @torch_only_method def stream_iterator(self, sent_stream): # streams for each data in the batch streams = [None] * self.bsz @@ -795,6 +822,7 @@ class LMMultiFileIterator(LMShuffledIterator): class TransfoXLCorpus(object): @classmethod + @torch_only_method def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): """ Instantiate a pre-processed corpus. @@ -892,10 +920,14 @@ class TransfoXLCorpus(object): data_iter = LMOrderedIterator(data, *args, **kwargs) elif self.dataset == "lm1b": data_iter = LMShuffledIterator(data, *args, **kwargs) + else: + data_iter = None + raise ValueError(f"Split not recognized: {split}") return data_iter +@torch_only_method def get_lm_corpus(datadir, dataset): fn = os.path.join(datadir, "cache.pt") fn_pickle = os.path.join(datadir, "cache.pkl")