Merge remote-tracking branch 'upstream/master'
This commit is contained in:
@@ -21,4 +21,4 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
|
||||
from .optimization import BertAdam
|
||||
from .optimization_openai import OpenAIAdam
|
||||
|
||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
|
||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||
|
||||
@@ -5,11 +5,13 @@ Copyright by the AllenNLP authors.
|
||||
"""
|
||||
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
||||
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import fnmatch
|
||||
from functools import wraps
|
||||
from hashlib import sha256
|
||||
import sys
|
||||
@@ -33,6 +35,9 @@ except (AttributeError, ImportError):
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -188,17 +193,30 @@ def get_from_cache(url, cache_dir=None):
|
||||
if url.startswith("s3://"):
|
||||
etag = s3_etag(url)
|
||||
else:
|
||||
response = requests.head(url, allow_redirects=True)
|
||||
if response.status_code != 200:
|
||||
raise IOError("HEAD request failed for url {} with status code {}"
|
||||
.format(url, response.status_code))
|
||||
etag = response.headers.get("ETag")
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True)
|
||||
if response.status_code != 200:
|
||||
etag = None
|
||||
else:
|
||||
etag = response.headers.get("ETag")
|
||||
except EnvironmentError:
|
||||
etag = None
|
||||
|
||||
if sys.version_info[0] == 2 and etag is not None:
|
||||
etag = etag.decode('utf-8')
|
||||
filename = url_to_filename(url, etag)
|
||||
|
||||
# get cache path to put the file
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
|
||||
# If we don't have a connection (etag is None) and can't identify the file
|
||||
# try to get the last downloaded one
|
||||
if not os.path.exists(cache_path) and etag is None:
|
||||
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
|
||||
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
|
||||
if matching_files:
|
||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||
|
||||
if not os.path.exists(cache_path):
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
@@ -223,8 +241,11 @@ def get_from_cache(url, cache_dir=None):
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {'url': url, 'etag': etag}
|
||||
meta_path = cache_path + '.json'
|
||||
with open(meta_path, 'w', encoding="utf-8") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
with open(meta_path, 'w') as meta_file:
|
||||
output_string = json.dumps(meta)
|
||||
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
||||
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
|
||||
meta_file.write(output_string)
|
||||
|
||||
logger.info("removing temp file %s", temp_file.name)
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
||||
}
|
||||
CONFIG_NAME = 'bert_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
BERT_CONFIG_NAME = 'bert_config.json'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
@@ -220,6 +219,11 @@ class BertConfig(object):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path):
|
||||
""" Save this instance to a json file."""
|
||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
|
||||
except ImportError:
|
||||
@@ -581,13 +585,16 @@ class BertPreTrainedModel(nn.Module):
|
||||
serialization_dir = tempdir
|
||||
# Load config
|
||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
if not os.path.exists(config_file):
|
||||
# Backward compatibility with old naming format
|
||||
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
|
||||
config = BertConfig.from_json_file(config_file)
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None and not from_tf:
|
||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||
state_dict = torch.load(weights_path, map_location='cpu')
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
@@ -930,7 +937,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,9 +42,6 @@ logger = logging.getLogger(__name__)
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
@@ -180,6 +177,11 @@ class GPT2Config(object):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path):
|
||||
""" Save this instance to a json file."""
|
||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
def __init__(self, nf, nx):
|
||||
@@ -216,7 +218,7 @@ class Attention(nn.Module):
|
||||
w = w / math.sqrt(v.size(-1))
|
||||
nd, ns = w.size(-2), w.size(-1)
|
||||
b = self.bias[:, :, ns-nd:ns, :ns]
|
||||
w = w * b - 1e10 * (1 - b)
|
||||
w = w * b - 1e4 * (1 - b)
|
||||
|
||||
w = nn.Softmax(dim=-1)(w)
|
||||
return torch.matmul(w, v)
|
||||
@@ -416,7 +418,7 @@ class GPT2PreTrainedModel(nn.Module):
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None and not from_tf:
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
||||
return load_tf_weights_in_gpt2(model, resolved_archive_file)
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__)
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
@@ -225,6 +223,11 @@ class OpenAIGPTConfig(object):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path):
|
||||
""" Save this instance to a json file."""
|
||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
def __init__(self, nf, rf, nx):
|
||||
@@ -473,7 +476,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None and not from_tf:
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
||||
return load_tf_weights_in_openai_gpt(model, resolved_archive_file)
|
||||
@@ -605,14 +608,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
return
|
||||
# Update config
|
||||
self.config.n_special = num_special_tokens
|
||||
# # Build new embeddings and initialize
|
||||
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
|
||||
old_embed = self.tokens_embed
|
||||
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
||||
# Initialize all new embeddings (in particular the special tokens)
|
||||
self.tokens_embed.to(old_embed.weight.device)
|
||||
self.init_weights(self.tokens_embed)
|
||||
# Copy word and positional embeddings from the previous weights
|
||||
self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
|
||||
self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
|
||||
# Copy word embeddings from the previous weights
|
||||
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
||||
if position_ids is None:
|
||||
@@ -717,9 +719,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
if lm_labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[:, :-1].contiguous()
|
||||
shift_labels = lm_labels[:, 1:].contiguous()
|
||||
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||
@@ -809,11 +810,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||
losses = []
|
||||
if lm_labels is not None:
|
||||
shift_logits = lm_logits[:, :-1].contiguous()
|
||||
shift_labels = lm_labels[:, 1:].contiguous()
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
losses.append(loss_fct(shift_logits.view(-1,
|
||||
shift_logits.size(-1)), shift_labels.view(-1)))
|
||||
losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
|
||||
if mc_labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
||||
|
||||
@@ -40,7 +40,7 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,8 +50,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
||||
}
|
||||
CONFIG_NAME = 'config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def build_tf_to_pytorch_map(model, config):
|
||||
@@ -316,6 +315,11 @@ class TransfoXLConfig(object):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path):
|
||||
""" Save this instance to a json file."""
|
||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, demb):
|
||||
@@ -940,7 +944,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None and not from_tf:
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path)
|
||||
|
||||
@@ -134,6 +134,21 @@ class BertTokenizer(object):
|
||||
tokens.append(self.ids_to_tokens[i])
|
||||
return tokens
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary to a directory or file."""
|
||||
index = 0
|
||||
if os.path.isdir(vocab_path):
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
|
||||
" Please check that the vocabulary is not corrupted!".format(vocab_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
return vocab_file
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -45,6 +46,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
}
|
||||
VOCAB_NAME = 'vocab.json'
|
||||
MERGES_NAME = 'merges.txt'
|
||||
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
@@ -57,6 +59,7 @@ def bytes_to_unicode():
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
_chr = unichr if sys.version_info[0] == 2 else chr
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
@@ -65,7 +68,7 @@ def bytes_to_unicode():
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
cs = [_chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
def get_pairs(word):
|
||||
@@ -94,9 +97,15 @@ class GPT2Tokenizer(object):
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
special_tokens_file = None
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
|
||||
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
|
||||
if not os.path.exists(special_tokens_file):
|
||||
special_tokens_file = None
|
||||
else:
|
||||
logger.info("loading special tokens file {}".format(special_tokens_file))
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
@@ -125,10 +134,14 @@ class GPT2Tokenizer(object):
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
|
||||
if special_tokens_file and 'special_tokens' not in kwargs:
|
||||
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
else:
|
||||
special_tokens = kwargs.pop('special_tokens', [])
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
|
||||
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.encoder = json.load(open(vocab_file))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
@@ -143,8 +156,25 @@ class GPT2Tokenizer(object):
|
||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder)
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
return
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
||||
logger.info("Special tokens {}".format(self.special_tokens))
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
@@ -187,20 +217,85 @@ class GPT2Tokenizer(object):
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
def tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
if len(bpe_tokens) > self.max_len:
|
||||
token = ''.join(self.byte_encoder[ord(b)] for b in token)
|
||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
return self.encoder.get(tokens, 0)
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.encoder.get(token, 0))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
|
||||
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return bpe_tokens
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
return text
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
||||
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
|
||||
|
||||
with open(vocab_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
|
||||
index = 0
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write(u'#version: 0.2\n')
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(merge_file))
|
||||
index = token_index
|
||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||
index += 1
|
||||
|
||||
index = len(self.encoder)
|
||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file, special_tokens_file
|
||||
|
||||
@@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
}
|
||||
VOCAB_NAME = 'vocab.json'
|
||||
MERGES_NAME = 'merges.txt'
|
||||
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||
|
||||
def get_pairs(word):
|
||||
"""
|
||||
@@ -86,9 +87,15 @@ class OpenAIGPTTokenizer(object):
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
special_tokens_file = None
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
|
||||
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
|
||||
if not os.path.exists(special_tokens_file):
|
||||
special_tokens_file = None
|
||||
else:
|
||||
logger.info("loading special tokens file {}".format(special_tokens_file))
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
@@ -117,7 +124,11 @@ class OpenAIGPTTokenizer(object):
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
|
||||
if special_tokens_file and 'special_tokens' not in kwargs:
|
||||
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
else:
|
||||
special_tokens = kwargs.pop('special_tokens', [])
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
|
||||
@@ -139,6 +150,8 @@ class OpenAIGPTTokenizer(object):
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
@@ -250,14 +263,51 @@ class OpenAIGPTTokenizer(object):
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
|
||||
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
if clean_up_tokenization_spaces:
|
||||
out_string = out_string.replace('<unk>', '')
|
||||
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
|
||||
).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't"
|
||||
).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
|
||||
).replace(" 've", "'ve")
|
||||
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
|
||||
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
||||
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
|
||||
|
||||
with open(vocab_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
|
||||
index = 0
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write(u'#version: 0.2\n')
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(merge_file))
|
||||
index = token_index
|
||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||
index += 1
|
||||
|
||||
index = len(self.encoder)
|
||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file, special_tokens_file
|
||||
|
||||
@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
else:
|
||||
vocab_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
@@ -141,6 +144,14 @@ class TransfoXLTokenizer(object):
|
||||
else:
|
||||
raise ValueError('No <unkown> token in vocabulary')
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary to a directory or file."""
|
||||
index = 0
|
||||
if os.path.isdir(vocab_path):
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
torch.save(self.__dict__, vocab_file)
|
||||
return vocab_file
|
||||
|
||||
def build_vocab(self):
|
||||
if self.vocab_file:
|
||||
print('building vocab from {}'.format(self.vocab_file))
|
||||
@@ -245,82 +256,24 @@ class TransfoXLTokenizer(object):
|
||||
def __len__(self):
|
||||
return len(self.idx2sym)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
if text in self.never_split:
|
||||
return [text]
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def whitespace_tokenize(self, text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
if self.delimiter == '':
|
||||
tokens = text
|
||||
else:
|
||||
tokens = text.split(self.delimiter)
|
||||
return tokens
|
||||
|
||||
def tokenize(self, line, add_eos=False, add_double_eos=False):
|
||||
line = self._clean_text(line)
|
||||
line = line.strip()
|
||||
# convert to lower case
|
||||
if self.lower_case:
|
||||
line = line.lower()
|
||||
|
||||
symbols = self.whitespace_tokenize(line)
|
||||
|
||||
split_symbols = []
|
||||
for symbol in symbols:
|
||||
if self.lower_case and symbol not in self.never_split:
|
||||
symbol = symbol.lower()
|
||||
symbol = self._run_strip_accents(symbol)
|
||||
split_symbols.extend(self._run_split_on_punc(symbol))
|
||||
# empty delimiter '' will evaluate False
|
||||
if self.delimiter == '':
|
||||
symbols = line
|
||||
else:
|
||||
symbols = line.split(self.delimiter)
|
||||
|
||||
if add_double_eos: # lm1b
|
||||
return ['<S>'] + split_symbols + ['<S>']
|
||||
return ['<S>'] + symbols + ['<S>']
|
||||
elif add_eos:
|
||||
return split_symbols + ['<eos>']
|
||||
return symbols + ['<eos>']
|
||||
else:
|
||||
return split_symbols
|
||||
return symbols
|
||||
|
||||
|
||||
class LMOrderedIterator(object):
|
||||
@@ -631,42 +584,3 @@ def get_lm_corpus(datadir, dataset):
|
||||
torch.save(corpus, fn)
|
||||
|
||||
return corpus
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user