added wordpiece - updated readme
This commit is contained in:
23
README.md
23
README.md
@@ -1,2 +1,23 @@
|
|||||||
# pytorch-pretrained-BERT
|
# pytorch-pretrained-BERT
|
||||||
A PyTorch version of Google's pretrained BERT model
|
A PyTorch version of Google's pretrained BERT model as described in
|
||||||
|
|
||||||
|
No bells and whitles, just:
|
||||||
|
- [one class](bert_model.py) with a clean commented version of Google's BERT model that can load the weights pre-trained by Google's authors,
|
||||||
|
- [another class](data_processor.py) with all you need to pre- and post-process text data for the model (tokenize and encode),
|
||||||
|
- and [a script](download_weigths.sh) to download Google's pre-trained weights.
|
||||||
|
|
||||||
|
Here is how to use these:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .bert_model import BERT
|
||||||
|
from .data_processor import DataProcessor
|
||||||
|
|
||||||
|
bert_model = BERT(bert_model_path='.')
|
||||||
|
data_processor = DataProcessor(bert_vocab_path='.')
|
||||||
|
|
||||||
|
input_sentence = "We are playing with the BERT model."
|
||||||
|
|
||||||
|
tensor_input = data_processor.encode(input_sentence)
|
||||||
|
tensor_output = bert_model(prepared_input)
|
||||||
|
output_sentence = data_processor.decode(tensor_output)
|
||||||
|
```
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from typing import NamedTuple, List
|
|||||||
import copy
|
import copy
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
@@ -273,10 +272,7 @@ class BERT(torch.nn.Module):
|
|||||||
config = BERTConfig(
|
config = BERTConfig(
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
num_heads,
|
num_heads,
|
||||||
embedding_dropout_probability,
|
dropout_probability,
|
||||||
attention_dropout_probability,
|
|
||||||
residual_dropout_probability,
|
|
||||||
activation_function,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# the embedding size is vocab_size + n_special embeddings + n_ctx
|
# the embedding size is vocab_size + n_special embeddings + n_ctx
|
||||||
@@ -288,7 +284,7 @@ class BERT(torch.nn.Module):
|
|||||||
self.num_output_layers = 1 + num_layers
|
self.num_output_layers = 1 + num_layers
|
||||||
|
|
||||||
self.embed = torch.nn.Embedding(embedding_size, embedding_dim)
|
self.embed = torch.nn.Embedding(embedding_size, embedding_dim)
|
||||||
self.drop = torch.nn.Dropout(embedding_dropout_probability)
|
self.drop = torch.nn.Dropout(dropout_probability)
|
||||||
|
|
||||||
block = Block(n_ctx, config, scale=True)
|
block = Block(n_ctx, config, scale=True)
|
||||||
self.h = torch.nn.ModuleList([copy.deepcopy(block) for _ in range(num_layers)])
|
self.h = torch.nn.ModuleList([copy.deepcopy(block) for _ in range(num_layers)])
|
||||||
@@ -332,16 +328,13 @@ class BERT(torch.nn.Module):
|
|||||||
names: List[str] = _PARAMETER_NAMES) -> None:
|
names: List[str] = _PARAMETER_NAMES) -> None:
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
|
|
||||||
logger.info(f"loading weights from {bert_model_path}")
|
|
||||||
# if `file_path` is a URL, redirect to the cache
|
|
||||||
|
|
||||||
with tarfile.open(bert_model_path) as tmp:
|
with tarfile.open(bert_model_path) as tmp:
|
||||||
num_params_files = len([member for member in tmp.getmembers() if member.name.endswith('.npy')])
|
num_params_files = len([member for member in tmp.getmembers() if member.name.endswith('.npy')])
|
||||||
shapesfile = tmp.extractfile('model/params_shapes.json')
|
shapesfile = tmp.extractfile('model/params_shapes.json')
|
||||||
if shapesfile:
|
if shapesfile:
|
||||||
shapes = json.loads(shapesfile.read())
|
shapes = json.loads(shapesfile.read())
|
||||||
else:
|
else:
|
||||||
raise ConfigurationError("unable to find model/params_shapes.json in the archive")
|
raise Exception("unable to find model/params_shapes.json in the archive")
|
||||||
|
|
||||||
# numpy can't read from a tarfile directly, so we need a workaround
|
# numpy can't read from a tarfile directly, so we need a workaround
|
||||||
# https://github.com/numpy/numpy/issues/7989#issuecomment-341656702
|
# https://github.com/numpy/numpy/issues/7989#issuecomment-341656702
|
||||||
|
|||||||
@@ -1,22 +1,129 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Tensor2Tensor Authors and Thomas Wolf
|
||||||
"""
|
"""
|
||||||
Prepare input data for Google's BERT Model.
|
Prepare input data for Google's BERT Model using WordPiece tokenization
|
||||||
|
and build arrays of word, position and sentence embeddings.
|
||||||
|
|
||||||
Contains some functions from tensor2tensor library: https://github.com/tensorflow/tensor2tensor
|
The WordPiece tokenization classes and functions are taken from the tensor2tensor library:
|
||||||
|
https://github.com/tensorflow/tensor2tensor
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
from typing import NamedTuple, List, Union, Tuple
|
from typing import NamedTuple, List, Union, Tuple
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import six
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import glob
|
||||||
|
import logging
|
||||||
|
import collections
|
||||||
|
import unicodedata
|
||||||
|
from itertools import chain
|
||||||
|
from six.moves import range # pylint: disable=redefined-builtin
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
TokenizedSentence = List[str]
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
TokenizedInput = Union[Tuple[TokenizedSentence, TokenizedSentence], TokenizedSentence]
|
|
||||||
|
|
||||||
class DataProcessor():
|
# Reserved tokens for things like padding and EOS symbols.
|
||||||
def __init__(self, vocab_path):
|
PAD = "<pad>"
|
||||||
self.encoder_file_path = encoder_file_path
|
EOS = "<EOS>"
|
||||||
self.token_indexer = json.load(open(vocab_path))
|
RESERVED_TOKENS = [PAD, EOS]
|
||||||
|
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
|
||||||
|
PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0
|
||||||
|
EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1
|
||||||
|
|
||||||
|
# Regular expression for unescaping token strings.
|
||||||
|
# '\u' is converted to '_'
|
||||||
|
# '\\' is converted to '\'
|
||||||
|
# '\213;' is converted to unichr(213)
|
||||||
|
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
|
||||||
|
_ESCAPE_CHARS = set(u"\\_u;0123456789")
|
||||||
|
|
||||||
|
# This set contains all letter and number characters.
|
||||||
|
_ALPHANUMERIC_CHAR_SET = set(
|
||||||
|
six.unichr(i) for i in range(sys.maxunicode)
|
||||||
|
if (unicodedata.category(six.unichr(i)).startswith("L") or
|
||||||
|
unicodedata.category(six.unichr(i)).startswith("N")))
|
||||||
|
|
||||||
|
|
||||||
def tokenize(text):
|
# Unicode utility functions that work with Python 2 and 3
|
||||||
"""Encode a unicode string as a list of tokens.
|
def native_to_unicode(s):
|
||||||
|
if is_unicode(s):
|
||||||
|
return s
|
||||||
|
try:
|
||||||
|
return to_unicode(s)
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
res = to_unicode(s, ignore_errors=True)
|
||||||
|
logger.info("Ignoring Unicode error, outputting: {}".format(res))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def unicode_to_native(s):
|
||||||
|
if six.PY2:
|
||||||
|
return s.encode("utf-8") if is_unicode(s) else s
|
||||||
|
else:
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def is_unicode(s):
|
||||||
|
if six.PY2:
|
||||||
|
if isinstance(s, unicode):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if isinstance(s, str):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def to_unicode(s, ignore_errors=False):
|
||||||
|
if is_unicode(s):
|
||||||
|
return s
|
||||||
|
error_mode = "ignore" if ignore_errors else "strict"
|
||||||
|
return s.decode("utf-8", errors=error_mode)
|
||||||
|
|
||||||
|
|
||||||
|
def to_unicode_ignore_errors(s):
|
||||||
|
return to_unicode(s, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_ids(ids, ids_to_strip):
|
||||||
|
"""Strip ids_to_strip from the end ids."""
|
||||||
|
ids = list(ids)
|
||||||
|
while ids and ids[-1] in ids_to_strip:
|
||||||
|
ids.pop()
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def tokenizer_encode(text):
|
||||||
|
"""A simple invertible tokenizer (in words).
|
||||||
|
|
||||||
|
Converts from a unicode string to a list of tokens
|
||||||
|
(represented as Unicode strings).
|
||||||
|
|
||||||
|
This tokenizer has the following desirable properties:
|
||||||
|
- It is invertible.
|
||||||
|
- Alphanumeric characters are broken away from non-alphanumeric characters.
|
||||||
|
- A single space between words does not produce an extra token.
|
||||||
|
- The full Unicode punctuation and separator set is recognized.
|
||||||
|
|
||||||
|
The tokenization algorithm is as follows:
|
||||||
|
|
||||||
|
1. Split the text into a list of tokens, splitting at every boundary of an
|
||||||
|
alphanumeric character and a non-alphanumeric character. This produces
|
||||||
|
a list which alternates between "alphanumeric tokens"
|
||||||
|
(strings of alphanumeric characters) and "non-alphanumeric tokens"
|
||||||
|
(strings of non-alphanumeric characters).
|
||||||
|
|
||||||
|
2. Remove every token consisting of a single space, unless it is
|
||||||
|
the very first or very last token in the list. These tokens are now
|
||||||
|
implied by the fact that there are two adjacent alphanumeric tokens.
|
||||||
|
|
||||||
|
e.g. u"Dude - that's so cool."
|
||||||
|
-> [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: a unicode string
|
text: a unicode string
|
||||||
@@ -40,7 +147,7 @@ class DataProcessor():
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def detokenize(tokens):
|
def tokenizer_decode(tokens):
|
||||||
"""Decode a list of tokens to a unicode string.
|
"""Decode a list of tokens to a unicode string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -56,34 +163,657 @@ class DataProcessor():
|
|||||||
ret.append(token)
|
ret.append(token)
|
||||||
return "".join(ret)
|
return "".join(ret)
|
||||||
|
|
||||||
def encode(input_sentences: List[TokenizedInput]) -> np.array:
|
|
||||||
|
class TextEncoder(object):
|
||||||
|
"""Base class for converting from ints to/from human readable strings."""
|
||||||
|
|
||||||
|
def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
|
||||||
|
self._num_reserved_ids = num_reserved_ids
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_reserved_ids(self):
|
||||||
|
return self._num_reserved_ids
|
||||||
|
|
||||||
|
def encode(self, s):
|
||||||
|
"""Transform a human-readable string into a sequence of int ids.
|
||||||
|
|
||||||
|
The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
|
||||||
|
num_reserved_ids) are reserved.
|
||||||
|
|
||||||
|
EOS is not appended.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s: human-readable string to be converted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ids: list of integers
|
||||||
|
"""
|
||||||
|
return [int(w) + self._num_reserved_ids for w in s.split()]
|
||||||
|
|
||||||
|
def decode(self, ids, strip_extraneous=False):
|
||||||
|
"""Transform a sequence of int ids into a human-readable string.
|
||||||
|
|
||||||
|
EOS is not expected in ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: list of integers to be converted.
|
||||||
|
strip_extraneous: bool, whether to strip off extraneous tokens
|
||||||
|
(EOS and PAD).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
s: human-readable string.
|
||||||
|
"""
|
||||||
|
if strip_extraneous:
|
||||||
|
ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
|
||||||
|
return " ".join(self.decode_list(ids))
|
||||||
|
|
||||||
|
def decode_list(self, ids):
|
||||||
|
"""Transform a sequence of int ids into a their string versions.
|
||||||
|
|
||||||
|
This method supports transforming individual input/output ids to their
|
||||||
|
string versions so that sequence to/from text conversions can be visualized
|
||||||
|
in a human readable format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: list of integers to be converted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
strs: list of human-readable string.
|
||||||
|
"""
|
||||||
|
decoded_ids = []
|
||||||
|
for id_ in ids:
|
||||||
|
if 0 <= id_ < self._num_reserved_ids:
|
||||||
|
decoded_ids.append(RESERVED_TOKENS[int(id_)])
|
||||||
|
else:
|
||||||
|
decoded_ids.append(id_ - self._num_reserved_ids)
|
||||||
|
return [str(d) for d in decoded_ids]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def _escape_token(token, alphabet):
|
||||||
|
"""Escape away underscores and OOV characters and append '_'.
|
||||||
|
|
||||||
|
This allows the token to be expressed as the concatenation of a list
|
||||||
|
of subtokens from the vocabulary. The underscore acts as a sentinel
|
||||||
|
which allows us to invertibly concatenate multiple such lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: A unicode string to be escaped.
|
||||||
|
alphabet: A set of all characters in the vocabulary's alphabet.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
escaped_token: An escaped unicode string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the provided token is not unicode.
|
||||||
|
"""
|
||||||
|
if not isinstance(token, six.text_type):
|
||||||
|
raise ValueError("Expected string type for token, got %s" % type(token))
|
||||||
|
|
||||||
|
token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
|
||||||
|
ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
|
||||||
|
return u"".join(ret) + "_"
|
||||||
|
|
||||||
|
|
||||||
|
def _unescape_token(escaped_token):
|
||||||
|
"""Inverse of _escape_token().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
escaped_token: a unicode string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
token: a unicode string
|
||||||
|
"""
|
||||||
|
|
||||||
|
def match(m):
|
||||||
|
if m.group(1) is None:
|
||||||
|
return u"_" if m.group(0) == u"\\u" else u"\\"
|
||||||
|
|
||||||
|
try:
|
||||||
|
return six.unichr(int(m.group(1)))
|
||||||
|
except (ValueError, OverflowError) as _:
|
||||||
|
return u"\u3013" # Unicode for undefined character.
|
||||||
|
|
||||||
|
trimmed = escaped_token[:-1] if escaped_token.endswith("_") else escaped_token
|
||||||
|
return _UNESCAPE_REGEX.sub(match, trimmed)
|
||||||
|
|
||||||
|
|
||||||
|
class SubwordTextEncoder(TextEncoder):
|
||||||
|
"""Class for invertibly encoding text using a limited vocabulary.
|
||||||
|
|
||||||
|
Invertibly encodes a native string as a sequence of subtokens from a limited
|
||||||
|
vocabulary.
|
||||||
|
|
||||||
|
A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
|
||||||
|
the corpus), and stored to a file. See text_encoder_build_subword.py.
|
||||||
|
|
||||||
|
It can then be loaded and used to encode/decode any text.
|
||||||
|
|
||||||
|
Encoding has four phases:
|
||||||
|
|
||||||
|
1. Tokenize into a list of tokens. Each token is a unicode string of either
|
||||||
|
all alphanumeric characters or all non-alphanumeric characters. We drop
|
||||||
|
tokens consisting of a single space that are between two alphanumeric
|
||||||
|
tokens.
|
||||||
|
|
||||||
|
2. Escape each token. This escapes away special and out-of-vocabulary
|
||||||
|
characters, and makes sure that each token ends with an underscore, and
|
||||||
|
has no other underscores.
|
||||||
|
|
||||||
|
3. Represent each escaped token as a the concatenation of a list of subtokens
|
||||||
|
from the limited vocabulary. Subtoken selection is done greedily from
|
||||||
|
beginning to end. That is, we construct the list in order, always picking
|
||||||
|
the longest subtoken in our vocabulary that matches a prefix of the
|
||||||
|
remaining portion of the encoded token.
|
||||||
|
|
||||||
|
4. Concatenate these lists. This concatenation is invertible due to the
|
||||||
|
fact that the trailing underscores indicate when one list is finished.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, filename=None):
|
||||||
|
"""Initialize and read from a file, if provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: filename from which to read vocab. If None, do not load a
|
||||||
|
vocab
|
||||||
|
"""
|
||||||
|
self._alphabet = set()
|
||||||
|
self.filename = filename
|
||||||
|
if filename is not None:
|
||||||
|
self._load_from_file(filename)
|
||||||
|
super(SubwordTextEncoder, self).__init__()
|
||||||
|
|
||||||
|
def encode(self, s):
|
||||||
|
"""Converts a native string to a list of subtoken ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s: a native string.
|
||||||
|
Returns:
|
||||||
|
a list of integers in the range [0, vocab_size)
|
||||||
|
"""
|
||||||
|
return self._tokens_to_subtoken_ids(
|
||||||
|
tokenizer_encode(native_to_unicode(s)))
|
||||||
|
|
||||||
|
def encode_without_tokenizing(self, token_text):
|
||||||
|
"""Converts string to list of subtoken ids without calling tokenizer.
|
||||||
|
|
||||||
|
This treats `token_text` as a single token and directly converts it
|
||||||
|
to subtoken ids. This may be useful when the default tokenizer doesn't
|
||||||
|
do what we want (e.g., when encoding text with tokens composed of lots of
|
||||||
|
nonalphanumeric characters). It is then up to the caller to make sure that
|
||||||
|
raw text is consistently converted into tokens. Only use this if you are
|
||||||
|
sure that `encode` doesn't suit your needs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_text: A native string representation of a single token.
|
||||||
|
Returns:
|
||||||
|
A list of subword token ids; i.e., integers in the range [0, vocab_size).
|
||||||
|
"""
|
||||||
|
return self._tokens_to_subtoken_ids([native_to_unicode(token_text)])
|
||||||
|
|
||||||
|
def decode(self, ids, strip_extraneous=False):
|
||||||
|
"""Converts a sequence of subtoken ids to a native string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: a list of integers in the range [0, vocab_size)
|
||||||
|
strip_extraneous: bool, whether to strip off extraneous tokens
|
||||||
|
(EOS and PAD).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a native string
|
||||||
|
"""
|
||||||
|
if strip_extraneous:
|
||||||
|
ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
|
||||||
|
return unicode_to_native(
|
||||||
|
tokenizer_decode(self._subtoken_ids_to_tokens(ids)))
|
||||||
|
|
||||||
|
def decode_list(self, ids):
|
||||||
|
return [self._subtoken_id_to_subtoken_string(s) for s in ids]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
"""The subtoken vocabulary size."""
|
||||||
|
return len(self._all_subtoken_strings)
|
||||||
|
|
||||||
|
def _tokens_to_subtoken_ids(self, tokens):
|
||||||
|
"""Converts a list of tokens to a list of subtoken ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: a list of strings.
|
||||||
|
Returns:
|
||||||
|
a list of integers in the range [0, vocab_size)
|
||||||
|
"""
|
||||||
|
ret = []
|
||||||
|
for token in tokens:
|
||||||
|
ret.extend(self._token_to_subtoken_ids(token))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _token_to_subtoken_ids(self, token):
|
||||||
|
"""Converts token to a list of subtoken ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: a string.
|
||||||
|
Returns:
|
||||||
|
a list of integers in the range [0, vocab_size)
|
||||||
|
"""
|
||||||
|
cache_location = hash(token) % self._cache_size
|
||||||
|
cache_key, cache_value = self._cache[cache_location]
|
||||||
|
if cache_key == token:
|
||||||
|
return cache_value
|
||||||
|
ret = self._escaped_token_to_subtoken_ids(
|
||||||
|
_escape_token(token, self._alphabet))
|
||||||
|
self._cache[cache_location] = (token, ret)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _subtoken_ids_to_tokens(self, subtokens):
|
||||||
|
"""Converts a list of subtoken ids to a list of tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subtokens: a list of integers in the range [0, vocab_size)
|
||||||
|
Returns:
|
||||||
|
a list of strings.
|
||||||
|
"""
|
||||||
|
concatenated = "".join(
|
||||||
|
[self._subtoken_id_to_subtoken_string(s) for s in subtokens])
|
||||||
|
split = concatenated.split("_")
|
||||||
|
ret = []
|
||||||
|
for t in split:
|
||||||
|
if t:
|
||||||
|
unescaped = _unescape_token(t + "_")
|
||||||
|
if unescaped:
|
||||||
|
ret.append(unescaped)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _subtoken_id_to_subtoken_string(self, subtoken):
|
||||||
|
"""Converts a subtoken integer ID to a subtoken string."""
|
||||||
|
if 0 <= subtoken < self.vocab_size:
|
||||||
|
return self._all_subtoken_strings[subtoken]
|
||||||
|
return u""
|
||||||
|
|
||||||
|
def _escaped_token_to_subtoken_strings(self, escaped_token):
|
||||||
|
"""Converts an escaped token string to a list of subtoken strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
escaped_token: An escaped token as a unicode string.
|
||||||
|
Returns:
|
||||||
|
A list of subtokens as unicode strings.
|
||||||
|
"""
|
||||||
|
# NOTE: This algorithm is greedy; it won't necessarily produce the "best"
|
||||||
|
# list of subtokens.
|
||||||
|
ret = []
|
||||||
|
start = 0
|
||||||
|
token_len = len(escaped_token)
|
||||||
|
while start < token_len:
|
||||||
|
for end in range(
|
||||||
|
min(token_len, start + self._max_subtoken_len), start, -1):
|
||||||
|
subtoken = escaped_token[start:end]
|
||||||
|
if subtoken in self._subtoken_string_to_id:
|
||||||
|
ret.append(subtoken)
|
||||||
|
start = end
|
||||||
|
break
|
||||||
|
|
||||||
|
else: # Did not break
|
||||||
|
# If there is no possible encoding of the escaped token then one of the
|
||||||
|
# characters in the token is not in the alphabet. This should be
|
||||||
|
# impossible and would be indicative of a bug.
|
||||||
|
assert False, "Token substring not found in subtoken vocabulary."
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _escaped_token_to_subtoken_ids(self, escaped_token):
|
||||||
|
"""Converts an escaped token string to a list of subtoken IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
escaped_token: An escaped token as a unicode string.
|
||||||
|
Returns:
|
||||||
|
A list of subtoken IDs as integers.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
self._subtoken_string_to_id[subtoken]
|
||||||
|
for subtoken in self._escaped_token_to_subtoken_strings(escaped_token)
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_from_generator(cls,
|
||||||
|
generator,
|
||||||
|
target_size,
|
||||||
|
max_subtoken_length=None,
|
||||||
|
reserved_tokens=None):
|
||||||
|
"""Builds a SubwordTextEncoder from the generated text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator: yields text.
|
||||||
|
target_size: int, approximate vocabulary size to create.
|
||||||
|
max_subtoken_length: Maximum length of a subtoken. If this is not set,
|
||||||
|
then the runtime and memory use of creating the vocab is quadratic in
|
||||||
|
the length of the longest token. If this is set, then it is instead
|
||||||
|
O(max_subtoken_length * length of longest token).
|
||||||
|
reserved_tokens: List of reserved tokens. The global variable
|
||||||
|
`RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
|
||||||
|
argument is `None`, it will use `RESERVED_TOKENS`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SubwordTextEncoder with `vocab_size` approximately `target_size`.
|
||||||
|
"""
|
||||||
|
token_counts = collections.defaultdict(int)
|
||||||
|
for item in generator:
|
||||||
|
for tok in tokenizer_encode(native_to_unicode(item)):
|
||||||
|
token_counts[tok] += 1
|
||||||
|
encoder = cls.build_to_target_size(
|
||||||
|
target_size, token_counts, 1, 1e3,
|
||||||
|
max_subtoken_length=max_subtoken_length,
|
||||||
|
reserved_tokens=reserved_tokens)
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_to_target_size(cls,
|
||||||
|
target_size,
|
||||||
|
token_counts,
|
||||||
|
min_val,
|
||||||
|
max_val,
|
||||||
|
max_subtoken_length=None,
|
||||||
|
reserved_tokens=None,
|
||||||
|
num_iterations=4):
|
||||||
|
"""Builds a SubwordTextEncoder that has `vocab_size` near `target_size`.
|
||||||
|
|
||||||
|
Uses simple recursive binary search to find a minimum token count that most
|
||||||
|
closely matches the `target_size`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_size: Desired vocab_size to approximate.
|
||||||
|
token_counts: A dictionary of token counts, mapping string to int.
|
||||||
|
min_val: An integer; lower bound for the minimum token count.
|
||||||
|
max_val: An integer; upper bound for the minimum token count.
|
||||||
|
max_subtoken_length: Maximum length of a subtoken. If this is not set,
|
||||||
|
then the runtime and memory use of creating the vocab is quadratic in
|
||||||
|
the length of the longest token. If this is set, then it is instead
|
||||||
|
O(max_subtoken_length * length of longest token).
|
||||||
|
reserved_tokens: List of reserved tokens. The global variable
|
||||||
|
`RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
|
||||||
|
argument is `None`, it will use `RESERVED_TOKENS`.
|
||||||
|
num_iterations: An integer; how many iterations of refinement.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A SubwordTextEncoder instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `min_val` is greater than `max_val`.
|
||||||
|
"""
|
||||||
|
if min_val > max_val:
|
||||||
|
raise ValueError("Lower bound for the minimum token count "
|
||||||
|
"is greater than the upper bound.")
|
||||||
|
if target_size < 1:
|
||||||
|
raise ValueError("Target size must be positive.")
|
||||||
|
|
||||||
|
if reserved_tokens is None:
|
||||||
|
reserved_tokens = RESERVED_TOKENS
|
||||||
|
|
||||||
|
def bisect(min_val, max_val):
|
||||||
|
"""Bisection to find the right size."""
|
||||||
|
present_count = (max_val + min_val) // 2
|
||||||
|
logger.info("Trying min_count %d" % present_count)
|
||||||
|
subtokenizer = cls()
|
||||||
|
subtokenizer.build_from_token_counts(
|
||||||
|
token_counts, present_count, num_iterations,
|
||||||
|
max_subtoken_length=max_subtoken_length,
|
||||||
|
reserved_tokens=reserved_tokens)
|
||||||
|
|
||||||
|
# Being within 1% of the target size is ok.
|
||||||
|
is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size
|
||||||
|
# If min_val == max_val, we can't do any better than this.
|
||||||
|
if is_ok or min_val >= max_val or present_count < 2:
|
||||||
|
return subtokenizer
|
||||||
|
|
||||||
|
if subtokenizer.vocab_size > target_size:
|
||||||
|
other_subtokenizer = bisect(present_count + 1, max_val)
|
||||||
|
else:
|
||||||
|
other_subtokenizer = bisect(min_val, present_count - 1)
|
||||||
|
|
||||||
|
if other_subtokenizer is None:
|
||||||
|
return subtokenizer
|
||||||
|
|
||||||
|
if (abs(other_subtokenizer.vocab_size - target_size) <
|
||||||
|
abs(subtokenizer.vocab_size - target_size)):
|
||||||
|
return other_subtokenizer
|
||||||
|
return subtokenizer
|
||||||
|
|
||||||
|
return bisect(min_val, max_val)
|
||||||
|
|
||||||
|
def build_from_token_counts(self,
|
||||||
|
token_counts,
|
||||||
|
min_count,
|
||||||
|
num_iterations=4,
|
||||||
|
reserved_tokens=None,
|
||||||
|
max_subtoken_length=None):
|
||||||
|
"""Train a SubwordTextEncoder based on a dictionary of word counts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_counts: a dictionary of Unicode strings to int.
|
||||||
|
min_count: an integer - discard subtokens with lower counts.
|
||||||
|
num_iterations: an integer. how many iterations of refinement.
|
||||||
|
reserved_tokens: List of reserved tokens. The global variable
|
||||||
|
`RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
|
||||||
|
argument is `None`, it will use `RESERVED_TOKENS`.
|
||||||
|
max_subtoken_length: Maximum length of a subtoken. If this is not set,
|
||||||
|
then the runtime and memory use of creating the vocab is quadratic in
|
||||||
|
the length of the longest token. If this is set, then it is instead
|
||||||
|
O(max_subtoken_length * length of longest token).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it
|
||||||
|
is not clear what the space is being reserved for, or when it will be
|
||||||
|
filled in.
|
||||||
|
"""
|
||||||
|
if reserved_tokens is None:
|
||||||
|
reserved_tokens = RESERVED_TOKENS
|
||||||
|
else:
|
||||||
|
# There is not complete freedom in replacing RESERVED_TOKENS.
|
||||||
|
for default, proposed in zip(RESERVED_TOKENS, reserved_tokens):
|
||||||
|
if default != proposed:
|
||||||
|
raise ValueError("RESERVED_TOKENS must be a prefix of "
|
||||||
|
"reserved_tokens.")
|
||||||
|
|
||||||
|
# Initialize the alphabet. Note, this must include reserved tokens or it can
|
||||||
|
# result in encoding failures.
|
||||||
|
alphabet_tokens = chain(six.iterkeys(token_counts),
|
||||||
|
[native_to_unicode(t) for t in reserved_tokens])
|
||||||
|
|
||||||
|
self._init_alphabet_from_tokens(alphabet_tokens)
|
||||||
|
|
||||||
|
# Bootstrap the initial list of subtokens with the characters from the
|
||||||
|
# alphabet plus the escaping characters.
|
||||||
|
self._init_subtokens_from_list(list(self._alphabet),
|
||||||
|
reserved_tokens=reserved_tokens)
|
||||||
|
|
||||||
|
# We build iteratively. On each iteration, we segment all the words,
|
||||||
|
# then count the resulting potential subtokens, keeping the ones
|
||||||
|
# with high enough counts for our new vocabulary.
|
||||||
|
if min_count < 1:
|
||||||
|
min_count = 1
|
||||||
|
for i in range(num_iterations):
|
||||||
|
logger.info("Iteration {0}".format(i))
|
||||||
|
|
||||||
|
# Collect all substrings of the encoded token that break along current
|
||||||
|
# subtoken boundaries.
|
||||||
|
subtoken_counts = collections.defaultdict(int)
|
||||||
|
for token, count in six.iteritems(token_counts):
|
||||||
|
iter_start_time = time.time()
|
||||||
|
escaped_token = _escape_token(token, self._alphabet)
|
||||||
|
subtokens = self._escaped_token_to_subtoken_strings(escaped_token)
|
||||||
|
start = 0
|
||||||
|
for subtoken in subtokens:
|
||||||
|
last_position = len(escaped_token) + 1
|
||||||
|
if max_subtoken_length is not None:
|
||||||
|
last_position = min(last_position, start + max_subtoken_length)
|
||||||
|
|
||||||
|
for end in range(start + 1, last_position):
|
||||||
|
new_subtoken = escaped_token[start:end]
|
||||||
|
subtoken_counts[new_subtoken] += count
|
||||||
|
start += len(subtoken)
|
||||||
|
iter_time_secs = time.time() - iter_start_time
|
||||||
|
if iter_time_secs > 0.1:
|
||||||
|
logger.info(u"Processing token [{0}] took {1} seconds, consider "
|
||||||
|
"setting Text2TextProblem.max_subtoken_length to a "
|
||||||
|
"smaller value.".format(token, iter_time_secs))
|
||||||
|
|
||||||
|
# Array of sets of candidate subtoken strings, by length.
|
||||||
|
len_to_subtoken_strings = []
|
||||||
|
for subtoken_string, count in six.iteritems(subtoken_counts):
|
||||||
|
lsub = len(subtoken_string)
|
||||||
|
if count >= min_count:
|
||||||
|
while len(len_to_subtoken_strings) <= lsub:
|
||||||
|
len_to_subtoken_strings.append(set())
|
||||||
|
len_to_subtoken_strings[lsub].add(subtoken_string)
|
||||||
|
|
||||||
|
# Consider the candidates longest to shortest, so that if we accept
|
||||||
|
# a longer subtoken string, we can decrement the counts of its prefixes.
|
||||||
|
new_subtoken_strings = []
|
||||||
|
for lsub in range(len(len_to_subtoken_strings) - 1, 0, -1):
|
||||||
|
subtoken_strings = len_to_subtoken_strings[lsub]
|
||||||
|
for subtoken_string in subtoken_strings:
|
||||||
|
count = subtoken_counts[subtoken_string]
|
||||||
|
if count >= min_count:
|
||||||
|
# Exclude alphabet tokens here, as they must be included later,
|
||||||
|
# explicitly, regardless of count.
|
||||||
|
if subtoken_string not in self._alphabet:
|
||||||
|
new_subtoken_strings.append((count, subtoken_string))
|
||||||
|
for l in range(1, lsub):
|
||||||
|
subtoken_counts[subtoken_string[:l]] -= count
|
||||||
|
|
||||||
|
# Include the alphabet explicitly to guarantee all strings are encodable.
|
||||||
|
new_subtoken_strings.extend((subtoken_counts.get(a, 0), a)
|
||||||
|
for a in self._alphabet)
|
||||||
|
new_subtoken_strings.sort(reverse=True)
|
||||||
|
|
||||||
|
# Reinitialize to the candidate vocabulary.
|
||||||
|
new_subtoken_strings = [subtoken for _, subtoken in new_subtoken_strings]
|
||||||
|
if reserved_tokens:
|
||||||
|
escaped_reserved_tokens = [
|
||||||
|
_escape_token(native_to_unicode(t), self._alphabet)
|
||||||
|
for t in reserved_tokens
|
||||||
|
]
|
||||||
|
new_subtoken_strings = escaped_reserved_tokens + new_subtoken_strings
|
||||||
|
|
||||||
|
self._init_subtokens_from_list(new_subtoken_strings)
|
||||||
|
logger.info("vocab_size = %d" % self.vocab_size)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_subtoken_strings(self):
|
||||||
|
return tuple(self._all_subtoken_strings)
|
||||||
|
|
||||||
|
def dump(self):
|
||||||
|
"""Debugging dump of the current subtoken vocabulary."""
|
||||||
|
subtoken_strings = [(i, s)
|
||||||
|
for s, i in six.iteritems(self._subtoken_string_to_id)]
|
||||||
|
print(u", ".join(u"{0} : '{1}'".format(i, s)
|
||||||
|
for i, s in sorted(subtoken_strings)))
|
||||||
|
|
||||||
|
def _init_subtokens_from_list(self, subtoken_strings, reserved_tokens=None):
|
||||||
|
"""Initialize token information from a list of subtoken strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subtoken_strings: a list of subtokens
|
||||||
|
reserved_tokens: List of reserved tokens. We must have `reserved_tokens`
|
||||||
|
as None or the empty list, or else the global variable `RESERVED_TOKENS`
|
||||||
|
must be a prefix of `reserved_tokens`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it
|
||||||
|
is not clear what the space is being reserved for, or when it will be
|
||||||
|
filled in.
|
||||||
|
"""
|
||||||
|
if reserved_tokens is None:
|
||||||
|
reserved_tokens = []
|
||||||
|
|
||||||
|
if reserved_tokens:
|
||||||
|
self._all_subtoken_strings = reserved_tokens + subtoken_strings
|
||||||
|
else:
|
||||||
|
self._all_subtoken_strings = subtoken_strings
|
||||||
|
|
||||||
|
# we remember the maximum length of any subtoken to avoid having to
|
||||||
|
# check arbitrarily long strings.
|
||||||
|
self._max_subtoken_len = max([len(s) for s in subtoken_strings])
|
||||||
|
self._subtoken_string_to_id = {
|
||||||
|
s: i + len(reserved_tokens)
|
||||||
|
for i, s in enumerate(subtoken_strings) if s
|
||||||
|
}
|
||||||
|
# Initialize the cache to empty.
|
||||||
|
self._cache_size = 2 ** 20
|
||||||
|
self._cache = [(None, None)] * self._cache_size
|
||||||
|
|
||||||
|
def _init_alphabet_from_tokens(self, tokens):
|
||||||
|
"""Initialize alphabet from an iterable of token or subtoken strings."""
|
||||||
|
# Include all characters from all tokens in the alphabet to guarantee that
|
||||||
|
# any token can be encoded. Additionally, include all escaping characters.
|
||||||
|
self._alphabet = {c for token in tokens for c in token}
|
||||||
|
self._alphabet |= _ESCAPE_CHARS
|
||||||
|
|
||||||
|
def _load_from_file_object(self, f):
|
||||||
|
"""Load from a file object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
f: File object to load vocabulary from
|
||||||
|
"""
|
||||||
|
subtoken_strings = []
|
||||||
|
for line in f:
|
||||||
|
s = line.strip()
|
||||||
|
# Some vocab files wrap words in single quotes, but others don't
|
||||||
|
if ((s.startswith("'") and s.endswith("'")) or
|
||||||
|
(s.startswith("\"") and s.endswith("\""))):
|
||||||
|
s = s[1:-1]
|
||||||
|
subtoken_strings.append(native_to_unicode(s))
|
||||||
|
self._init_subtokens_from_list(subtoken_strings)
|
||||||
|
self._init_alphabet_from_tokens(subtoken_strings)
|
||||||
|
|
||||||
|
def _load_from_file(self, filename):
|
||||||
|
"""Load from a vocab file."""
|
||||||
|
if not os.path.isfile(filename):
|
||||||
|
raise ValueError("File %s not found" % filename)
|
||||||
|
with open(filename) as f:
|
||||||
|
self._load_from_file_object(f)
|
||||||
|
|
||||||
|
def store_to_file(self, filename, add_single_quotes=True):
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
for subtoken_string in self._all_subtoken_strings:
|
||||||
|
if add_single_quotes:
|
||||||
|
f.write("'" + unicode_to_native(subtoken_string) + "'\n")
|
||||||
|
else:
|
||||||
|
f.write(unicode_to_native(subtoken_string) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessor():
|
||||||
|
def __init__(self, bert_vocab_path, n_ctx=512):
|
||||||
|
self.text_encoder = SubwordTextEncoder(bert_vocab_path)
|
||||||
|
self.clf_token = self.text_encoder.decode('Clf')
|
||||||
|
self.sep_token = self.text_encoder.decode('Sep')
|
||||||
|
self.A_token = self.text_encoder.decode('A')
|
||||||
|
self.B_token = self.text_encoder.decode('B')
|
||||||
|
self.n_ctx = 512
|
||||||
|
|
||||||
|
def encode_single_sentences(self, input_sentences: List[str]) -> np.array:
|
||||||
""" Prepare a torch.Tensor of inputs for BERT model from a string.
|
""" Prepare a torch.Tensor of inputs for BERT model from a string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_sentences: list of
|
input_sentences: list of single sentences (always considered as a sentence_A type)
|
||||||
- pairs of tokenized sentences (sentence_A, sentence_B) or
|
|
||||||
- tokenized sentences (will be considered as sentence_A only)
|
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Numpy array of formated inputs for BERT model
|
Numpy array of formated inputs for BERT model
|
||||||
"""
|
"""
|
||||||
batch_size = sum(min(len(x), n_perso_permute) for x in X1)
|
batch_size = len(input_sentences)
|
||||||
input_mask = np.zeros((n_batch, n_cands, n_ctx), dtype=np.float32)
|
input_array = np.zeros((batch_size, self.n_ctx, 3), dtype=np.int32)
|
||||||
input_array = np.zeros((n_batch, n_cands, n_ctx, 3), dtype=np.int32)
|
input_mask = np.zeros((batch_size, self.n_ctx), dtype=np.float32)
|
||||||
i = 0
|
i = 0
|
||||||
for tokenized_input in input_sentences:
|
for sentence in input_sentences:
|
||||||
x1j, lxj, lperso, lhisto, dialog_embed = format_transformer_input(x1, x2, xcand_j, text_encoder,
|
tokenized_sentence = self.text_encoder.encode(sentence)
|
||||||
dialog_embed_mode, max_len=max_len,
|
x1j = [self.clf_token] + tokenized_sentence
|
||||||
add_start_stop=True)
|
lxj = len(x1j)
|
||||||
lmj = len(xcand_j[:max_len]) + 1
|
input_array[i, :lxj, 0] = x1j
|
||||||
xmb[i, j, :lxj, 0] = x1j
|
input_array[i, :lxj, 0] = [self.A_token] * lxj
|
||||||
if dialog_embed_mode == 1 or dialog_embed_mode == 2:
|
input_array[i, :, 1] = np.arange(self.n_vocab+self.n_special, self.n_vocab+self.n_special+self.n_ctx)
|
||||||
xmb[i, j, :lxj, 2] = dialog_embed
|
input_mask[i, :lxj] = 1
|
||||||
mmb[i, j, :lxj] = 1
|
|
||||||
if fix_lm_index: # Take one before so we don't predict from classify token...
|
|
||||||
mmb_eval[i, j, (lxj-lmj-1):lxj-1] = 1 # This one only mask the response so we get the perplexity on the response only
|
|
||||||
else:
|
|
||||||
mmb_eval[i, j, (lxj-lmj):lxj] = 1 # This one only mask the response so we get the perplexity on the response only
|
|
||||||
xmb[i, j, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx)
|
|
||||||
i += 1
|
i += 1
|
||||||
return input_array, input_mask
|
return input_array, input_mask
|
||||||
18
example.py
18
example.py
@@ -1,18 +0,0 @@
|
|||||||
"""
|
|
||||||
Show how to use HuggingFace's PyTorch implementation of Google's BERT Model.
|
|
||||||
"""
|
|
||||||
from .bert_model import BERT
|
|
||||||
from .prepare_inputs import DataPreprocessor
|
|
||||||
|
|
||||||
bert_model = BERT()
|
|
||||||
bert_model.load_from('.')
|
|
||||||
data_processor = DataProcessor(encoder_file_path='.')
|
|
||||||
|
|
||||||
input_sentence = "We are playing with the BERT model."
|
|
||||||
print("BERT inputs: {}".format(input_sentence))
|
|
||||||
|
|
||||||
tensor_input = data_processor.encode(input_sentence)
|
|
||||||
tensor_output = bert_model(prepared_input)
|
|
||||||
output_sentence = data_processor.decode(tensor_output)
|
|
||||||
|
|
||||||
print("BERT predicted: {}".format(output_sentence))
|
|
||||||
Reference in New Issue
Block a user