Replaced torch.load for loading the pretrained vocab of TransformerXL tokenizer to pickle.load (#6935)
* Replaced torch.load for loading the pretrained vocab of TransformerXL to pickle.load
* Replaced torch.save with pickle.dump when saving the vocabulary
* updating transformer-xl
* uploaded on S3 - compatibility
* fix tests
* style
* Address review comments

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -203,6 +203,19 @@ def is_faiss_available():
|
|||||||
return _faiss_available
|
return _faiss_available
|
||||||
|
|
||||||
|
|
||||||
|
def torch_only_method(fn):
    """Decorate ``fn`` so that calling it without PyTorch raises ``ImportError``.

    The module-level ``_torch_available`` flag is read each time the wrapped
    callable is invoked, not when the decorator is applied.

    Args:
        fn: The function or method to guard.

    Returns:
        A wrapper that forwards all positional and keyword arguments to ``fn``
        and returns its result. The wrapper keeps ``fn``'s ``__name__``,
        ``__doc__`` and other metadata.

    Raises:
        ImportError: When the wrapper is called while ``_torch_available`` is
            falsy.
    """
    # Imported locally so this block does not depend on the file's import section.
    from functools import wraps

    # Without ``wraps`` every decorated method would report itself as "wrapper"
    # and lose its docstring, which hurts introspection and generated docs.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        return fn(*args, **kwargs)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
def is_sklearn_available():
|
def is_sklearn_available():
|
||||||
return _has_sklearn
|
return _has_sklearn
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalize
|
|||||||
from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
|
from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
|
||||||
from tokenizers.processors import BertProcessing
|
from tokenizers.processors import BertProcessing
|
||||||
|
|
||||||
from .file_utils import cached_path, is_torch_available
|
from .file_utils import cached_path, is_torch_available, torch_only_method
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
@@ -48,12 +48,16 @@ if is_torch_available():
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"}
|
VOCAB_FILES_NAMES = {
|
||||||
|
"pretrained_vocab_file": "vocab.pkl",
|
||||||
|
"pretrained_vocab_file_torch": "vocab.bin",
|
||||||
|
"vocab_file": "vocab.txt",
|
||||||
|
}
|
||||||
VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}
|
VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}
|
||||||
|
|
||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
"pretrained_vocab_file": {
|
"pretrained_vocab_file": {
|
||||||
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
|
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.pkl",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,8 +143,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
File containing the vocabulary (from the original implementation).
|
File containing the vocabulary (from the original implementation).
|
||||||
pretrained_vocab_file (:obj:`str`, `optional`):
|
pretrained_vocab_file (:obj:`str`, `optional`):
|
||||||
File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
|
File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
|
||||||
never_split (xxx, `optional`):
|
never_split (:obj:`List[str]`, `optional`):
|
||||||
Fill me with intesting stuff.
|
List of tokens that should never be split. If no list is specified, will simply use the existing
|
||||||
|
special tokens.
|
||||||
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
|
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
|
||||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
token instead.
|
token instead.
|
||||||
@@ -165,7 +170,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
lower_case=False,
|
lower_case=False,
|
||||||
delimiter=None,
|
delimiter=None,
|
||||||
vocab_file=None,
|
vocab_file=None,
|
||||||
pretrained_vocab_file=None,
|
pretrained_vocab_file: str = None,
|
||||||
never_split=None,
|
never_split=None,
|
||||||
unk_token="<unk>",
|
unk_token="<unk>",
|
||||||
eos_token="<eos>",
|
eos_token="<eos>",
|
||||||
@@ -197,23 +202,40 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
self.moses_tokenizer = sm.MosesTokenizer(language)
|
self.moses_tokenizer = sm.MosesTokenizer(language)
|
||||||
self.moses_detokenizer = sm.MosesDetokenizer(language)
|
self.moses_detokenizer = sm.MosesDetokenizer(language)
|
||||||
|
|
||||||
|
# This try... catch... is not beautiful but honestly this tokenizer was not made to be used
|
||||||
|
# in a library like ours, at all.
|
||||||
try:
|
try:
|
||||||
|
vocab_dict = None
|
||||||
if pretrained_vocab_file is not None:
|
if pretrained_vocab_file is not None:
|
||||||
# Hack because, honestly this tokenizer was not made to be used
|
# Priority on pickle files (support PyTorch and TF)
|
||||||
# in a library like ours, at all.
|
with open(pretrained_vocab_file, "rb") as f:
|
||||||
vocab_dict = torch.load(pretrained_vocab_file)
|
vocab_dict = pickle.load(f)
|
||||||
|
|
||||||
|
# Loading a torch-saved transfo-xl vocab dict with pickle results in an integer
|
||||||
|
# Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed.
|
||||||
|
# We therefore load it with torch, if it's available.
|
||||||
|
if type(vocab_dict) == int:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ImportError(
|
||||||
|
"Not trying to load dict with PyTorch as you need to install pytorch to load "
|
||||||
|
"from a PyTorch pretrained vocabulary, "
|
||||||
|
"or activate it with environment variables USE_TORCH=1 and USE_TF=0."
|
||||||
|
)
|
||||||
|
vocab_dict = torch.load(pretrained_vocab_file)
|
||||||
|
|
||||||
|
if vocab_dict is not None:
|
||||||
for key, value in vocab_dict.items():
|
for key, value in vocab_dict.items():
|
||||||
if key not in self.__dict__:
|
if key not in self.__dict__:
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
|
elif vocab_file is not None:
|
||||||
if vocab_file is not None:
|
|
||||||
self.build_vocab()
|
self.build_vocab()
|
||||||
except Exception:
|
|
||||||
|
except Exception as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unable to parse file {}. Unknown format. "
|
"Unable to parse file {}. Unknown format. "
|
||||||
"If you tried to load a model saved through TransfoXLTokenizerFast,"
|
"If you tried to load a model saved through TransfoXLTokenizerFast,"
|
||||||
"please note they are not compatible.".format(pretrained_vocab_file)
|
"please note they are not compatible.".format(pretrained_vocab_file)
|
||||||
)
|
) from e
|
||||||
|
|
||||||
if vocab_file is not None:
|
if vocab_file is not None:
|
||||||
self.build_vocab()
|
self.build_vocab()
|
||||||
@@ -286,7 +308,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
|
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
|
||||||
else:
|
else:
|
||||||
vocab_file = vocab_path
|
vocab_file = vocab_path
|
||||||
torch.save(self.__dict__, vocab_file)
|
with open(vocab_file, "wb") as f:
|
||||||
|
pickle.dump(self.__dict__, f)
|
||||||
return (vocab_file,)
|
return (vocab_file,)
|
||||||
|
|
||||||
def build_vocab(self):
|
def build_vocab(self):
|
||||||
@@ -309,6 +332,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))
|
logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))
|
||||||
|
|
||||||
|
@torch_only_method
|
||||||
def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
|
def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
|
||||||
if verbose:
|
if verbose:
|
||||||
logger.info("encoding file {} ...".format(path))
|
logger.info("encoding file {} ...".format(path))
|
||||||
@@ -326,6 +350,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
return encoded
|
return encoded
|
||||||
|
|
||||||
|
@torch_only_method
|
||||||
def encode_sents(self, sents, ordered=False, verbose=False):
|
def encode_sents(self, sents, ordered=False, verbose=False):
|
||||||
if verbose:
|
if verbose:
|
||||||
logger.info("encoding {} sents ...".format(len(sents)))
|
logger.info("encoding {} sents ...".format(len(sents)))
|
||||||
@@ -436,6 +461,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
out_string = self.moses_detokenizer.detokenize(tokens)
|
out_string = self.moses_detokenizer.detokenize(tokens)
|
||||||
return detokenize_numbers(out_string).strip()
|
return detokenize_numbers(out_string).strip()
|
||||||
|
|
||||||
|
@torch_only_method
|
||||||
def convert_to_tensor(self, symbols):
|
def convert_to_tensor(self, symbols):
|
||||||
return torch.LongTensor(self.convert_tokens_to_ids(symbols))
|
return torch.LongTensor(self.convert_tokens_to_ids(symbols))
|
||||||
|
|
||||||
@@ -706,6 +732,7 @@ class LMShuffledIterator(object):
|
|||||||
for idx in epoch_indices:
|
for idx in epoch_indices:
|
||||||
yield self.data[idx]
|
yield self.data[idx]
|
||||||
|
|
||||||
|
@torch_only_method
|
||||||
def stream_iterator(self, sent_stream):
|
def stream_iterator(self, sent_stream):
|
||||||
# streams for each data in the batch
|
# streams for each data in the batch
|
||||||
streams = [None] * self.bsz
|
streams = [None] * self.bsz
|
||||||
@@ -795,6 +822,7 @@ class LMMultiFileIterator(LMShuffledIterator):
|
|||||||
|
|
||||||
class TransfoXLCorpus(object):
|
class TransfoXLCorpus(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@torch_only_method
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a pre-processed corpus.
|
Instantiate a pre-processed corpus.
|
||||||
@@ -892,10 +920,14 @@ class TransfoXLCorpus(object):
|
|||||||
data_iter = LMOrderedIterator(data, *args, **kwargs)
|
data_iter = LMOrderedIterator(data, *args, **kwargs)
|
||||||
elif self.dataset == "lm1b":
|
elif self.dataset == "lm1b":
|
||||||
data_iter = LMShuffledIterator(data, *args, **kwargs)
|
data_iter = LMShuffledIterator(data, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
data_iter = None
|
||||||
|
raise ValueError(f"Split not recognized: {split}")
|
||||||
|
|
||||||
return data_iter
|
return data_iter
|
||||||
|
|
||||||
|
|
||||||
|
@torch_only_method
|
||||||
def get_lm_corpus(datadir, dataset):
|
def get_lm_corpus(datadir, dataset):
|
||||||
fn = os.path.join(datadir, "cache.pt")
|
fn = os.path.join(datadir, "cache.pt")
|
||||||
fn_pickle = os.path.join(datadir, "cache.pkl")
|
fn_pickle = os.path.join(datadir, "cache.pkl")
|
||||||
|
|||||||
Reference in New Issue
Block a user