Reorganize repo (#8580)

* Put models in subfolders

* Styling

* Fix imports in tests

* More fixes in test imports

* Sneaky hidden imports

* Fix imports in doc files

* More sneaky imports

* Finish fixing tests

* Fix examples

* Fix path for copies

* More fixes for examples

* Fix dummy files

* More fixes for example

* More model import fixes

* Is this why you're unhappy GitHub?

* Fix imports in conver command
This commit is contained in:
Sylvain Gugger
2020-11-16 21:43:42 -05:00
committed by GitHub
parent 901507335f
commit c89bdfbe72
381 changed files with 2651 additions and 1571 deletions

View File

@@ -0,0 +1,48 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
if is_tokenizers_available():
from .tokenization_bert_fast import BertTokenizerFast
if is_torch_available():
from .modeling_bert import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMHeadModel,
BertModel,
BertPreTrainedModel,
load_tf_weights_in_bert,
)
if is_tf_available():
from .modeling_tf_bert import (
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBertEmbeddings,
TFBertForMaskedLM,
TFBertForMultipleChoice,
TFBertForNextSentencePrediction,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFBertForTokenClassification,
TFBertLMHeadModel,
TFBertMainLayer,
TFBertModel,
TFBertPreTrainedModel,
)
if is_flax_available():
from .modeling_flax_bert import FlaxBertModel

View File

@@ -0,0 +1,142 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json",
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json",
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json",
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json",
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json",
"bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json",
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json",
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json",
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json",
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json",
"cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json",
"cl-tohoku/bert-base-japanese-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json",
"cl-tohoku/bert-base-japanese-char": "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json",
"cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json",
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json",
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
# See all BERT models at https://huggingface.co/models?filter=bert
}
class BertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.BertModel` or a
:class:`~transformers.TFBertModel`. It is used to instantiate a BERT model according to the specified arguments,
defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
to that of the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
:class:`~transformers.TFBertModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
:class:`~transformers.TFBertModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Examples::
>>> from transformers import BertModel, BertConfig
>>> # Initializing a BERT bert-base-uncased style configuration
>>> configuration = BertConfig()
>>> # Initializing a model from the bert-base-uncased style configuration
>>> model = BertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "bert"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing

View File

@@ -0,0 +1,226 @@
"""
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official GitHub:
https://github.com/tensorflow/models/tree/master/official/nlp/bert
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
weight names to the original names, so the model can be imported with Huggingface/transformer.
You may adapt this script to include classification/MLM/NSP/etc. heads.
"""
import argparse
import os
import re
import tensorflow as tf
import torch
from transformers import BertConfig, BertModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
layer_depth = []
for full_name, shape in init_vars:
# logger.info("Loading TF weight {} with shape {}".format(name, shape))
name = full_name.split("/")
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
logger.info(f"Skipping non-model layer {full_name}")
continue
if "optimizer" in full_name:
logger.info(f"Skipping optimization layer {full_name}")
continue
if name[0] == "model":
# ignore initial 'model'
name = name[1:]
# figure out how many levels deep the name is
depth = 0
for _name in name:
if _name.startswith("layer_with_weights"):
depth += 1
else:
break
layer_depth.append(depth)
# read data
array = tf.train.load_variable(tf_path, full_name)
names.append("/".join(name))
arrays.append(array)
logger.info(f"Read a total of {len(arrays):,} layers")
# Sanity check
if len(set(layer_depth)) != 1:
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
layer_depth = list(set(layer_depth))[0]
if layer_depth != 1:
raise ValueError(
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP heads."
)
# convert layers
logger.info("Converting weights...")
for full_name, array in zip(names, arrays):
name = full_name.split("/")
pointer = model
trace = []
for i, m_name in enumerate(name):
if m_name == ".ATTRIBUTES":
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
break
if m_name.startswith("layer_with_weights"):
layer_num = int(m_name.split("-")[-1])
if layer_num <= 2:
# embedding layers
# layer_num 0: word_embeddings
# layer_num 1: position_embeddings
# layer_num 2: token_type_embeddings
continue
elif layer_num == 3:
# embedding LayerNorm
trace.extend(["embeddings", "LayerNorm"])
pointer = getattr(pointer, "embeddings")
pointer = getattr(pointer, "LayerNorm")
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
# encoder layers
trace.extend(["encoder", "layer", str(layer_num - 4)])
pointer = getattr(pointer, "encoder")
pointer = getattr(pointer, "layer")
pointer = pointer[layer_num - 4]
elif layer_num == config.num_hidden_layers + 4:
# pooler layer
trace.extend(["pooler", "dense"])
pointer = getattr(pointer, "pooler")
pointer = getattr(pointer, "dense")
elif m_name == "embeddings":
trace.append("embeddings")
pointer = getattr(pointer, "embeddings")
if layer_num == 0:
trace.append("word_embeddings")
pointer = getattr(pointer, "word_embeddings")
elif layer_num == 1:
trace.append("position_embeddings")
pointer = getattr(pointer, "position_embeddings")
elif layer_num == 2:
trace.append("token_type_embeddings")
pointer = getattr(pointer, "token_type_embeddings")
else:
raise ValueError("Unknown embedding layer with name {full_name}")
trace.append("weight")
pointer = getattr(pointer, "weight")
elif m_name == "_attention_layer":
# self-attention layer
trace.extend(["attention", "self"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "self")
elif m_name == "_attention_layer_norm":
# output attention norm
trace.extend(["attention", "output", "LayerNorm"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_attention_output_dense":
# output attention dense
trace.extend(["attention", "output", "dense"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_dense":
# output dense
trace.extend(["output", "dense"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output dense
trace.extend(["output", "LayerNorm"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_key_dense":
# attention key
trace.append("key")
pointer = getattr(pointer, "key")
elif m_name == "_query_dense":
# attention query
trace.append("query")
pointer = getattr(pointer, "query")
elif m_name == "_value_dense":
# attention value
trace.append("value")
pointer = getattr(pointer, "value")
elif m_name == "_intermediate_dense":
# attention intermediate dense
trace.extend(["intermediate", "dense"])
pointer = getattr(pointer, "intermediate")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output layer norm
trace.append("output")
pointer = getattr(pointer, "output")
# weights & biases
elif m_name in ["bias", "beta"]:
trace.append("bias")
pointer = getattr(pointer, "bias")
elif m_name in ["kernel", "gamma"]:
trace.append("weight")
pointer = getattr(pointer, "weight")
else:
logger.warning(f"Ignored {m_name}")
# for certain layers reshape is necessary
trace = ".".join(trace)
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
r"(\S+)\.attention\.output\.dense\.weight", trace
):
array = array.reshape(pointer.data.shape)
if "kernel" in full_name:
array = array.transpose()
if pointer.shape == array.shape:
pointer.data = torch.from_numpy(array)
else:
raise ValueError(
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape: {array.shape}"
)
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
return model
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
# Instantiate model
logger.info(f"Loading model based on config from {config_path}...")
config = BertConfig.from_json_file(config_path)
model = BertModel(config)
# Load weights from checkpoint
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
# Save pytorch-model
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
)
parser.add_argument(
"--bert_config_file",
type=str,
required=True,
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path",
type=str,
required=True,
help="Path to the output PyTorch model (must include filename).",
)
args = parser.parse_args()
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

View File

@@ -0,0 +1,61 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
import argparse
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--bert_config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

View File

@@ -0,0 +1,113 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
import argparse
import os
import numpy as np
import tensorflow as tf
import torch
from transformers import BertModel
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
"""
Args:
model: BertModel Pytorch model instance to be converted
ckpt_dir: Tensorflow model directory
model_name: model name
Currently supported HF models:
- Y BertModel
- N BertForMaskedLM
- N BertForPreTraining
- N BertForMultipleChoice
- N BertForNextSentencePrediction
- N BertForSequenceClassification
- N BertForQuestionAnswering
"""
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
var_map = (
("layer.", "layer_"),
("word_embeddings.weight", "word_embeddings"),
("position_embeddings.weight", "position_embeddings"),
("token_type_embeddings.weight", "token_type_embeddings"),
(".", "/"),
("LayerNorm/weight", "LayerNorm/gamma"),
("LayerNorm/bias", "LayerNorm/beta"),
("weight", "kernel"),
)
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
state_dict = model.state_dict()
def to_tf_var_name(name: str):
for patt, repl in iter(var_map):
name = name.replace(patt, repl)
return "bert/{}".format(name)
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
session.run(tf.variables_initializer([tf_var]))
session.run(tf_var)
return tf_var
tf.reset_default_graph()
with tf.Session() as session:
for var_name in state_dict:
tf_name = to_tf_var_name(var_name)
torch_tensor = state_dict[var_name].numpy()
if any([x in var_name for x in tensors_to_transpose]):
torch_tensor = torch_tensor.T
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
tf.keras.backend.set_value(tf_var, torch_tensor)
tf_weight = session.run(tf_var)
print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
saver = tf.train.Saver(tf.trainable_variables())
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
def main(raw_args=None):
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. bert-base-uncased")
parser.add_argument(
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
)
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
args = parser.parse_args(raw_args)
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir,
)
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,419 @@
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
from ...utils import logging
from .configuration_bert import BertConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
BERT_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)
This model is also a Flax Linen `flax.nn.Module
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
Parameters:
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
class FlaxBertLayerNorm(nn.Module):
"""
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
"""
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones
@nn.compact
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
Args:
x: the inputs
Returns:
Normalized inputs (the same shape as inputs).
"""
features = x.shape[-1]
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
return y
class FlaxBertEmbedding(nn.Module):
"""
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
use 'weight'
"""
vocab_size: int
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
@nn.compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
class FlaxBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embed
w_emb = FlaxBertEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
jnp.atleast_2d(input_ids.astype("i4"))
)
p_emb = FlaxBertEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
jnp.atleast_2d(position_ids.astype("i4"))
)
t_emb = FlaxBertEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
jnp.atleast_2d(token_type_ids.astype("i4"))
)
# Sum all embeddings
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
# Layer Norm
layer_norm = FlaxBertLayerNorm(name="layer_norm")(summed_emb)
return layer_norm
class FlaxBertAttention(nn.Module):
num_heads: int
head_size: int
@nn.compact
def __call__(self, hidden_state, attention_mask):
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
)
layer_norm = FlaxBertLayerNorm(name="layer_norm")(self_att + hidden_state)
return layer_norm
class FlaxBertIntermediate(nn.Module):
output_size: int
@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
return gelu(dense)
class FlaxBertOutput(nn.Module):
@nn.compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
return hidden_state
class FlaxBertLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
@nn.compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
output = FlaxBertOutput(name="output")(intermediate, attention)
return output
class FlaxBertLayerCollection(nn.Module):
"""
Stores N BertLayer(s)
"""
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
@nn.compact
def __call__(self, inputs, attention_mask):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
# Initialize input / output
input_i = inputs
# Forward over all encoders
for i in range(self.num_layers):
layer = FlaxBertLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
input_i = layer(input_i, attention_mask)
return input_i
class FlaxBertEncoder(nn.Module):
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
@nn.compact
def __call__(self, hidden_state, attention_mask):
layer = FlaxBertLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
)(hidden_state, attention_mask)
return layer
class FlaxBertPooler(nn.Module):
@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
return jax.lax.tanh(out)
class FlaxBertModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
# Embedding
embeddings = FlaxBertEmbeddings(
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
)(input_ids, token_type_ids, position_ids, attention_mask)
# N stacked encoding layers
encoder = FlaxBertEncoder(
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
)(embeddings, attention_mask)
pooled = FlaxBertPooler(name="pooler")(encoder)
return encoder, pooled
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
"""
model_class = FlaxBertModule
config_class = BertConfig
base_model_prefix = "bert"
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
jax_state = dict(pt_state)
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
for key, tensor in pt_state.items():
# Key parts
key_parts = set(key.split("."))
# Every dense layer has "kernel" parameters instead of "weight"
if "dense.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor
# SelfAttention needs also to replace "weight" by "kernel"
if {"query", "key", "value"} & key_parts:
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
if "bias" in key:
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
elif "weight":
del jax_state[key]
key = key.replace("weight", "kernel")
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove one nesting
if "attention.output.dense" in key:
del jax_state[key]
key = key.replace("attention.output.dense", "attention.self.out")
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove nesting on layer norm
if "attention.output.LayerNorm" in key:
del jax_state[key]
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
jax_state[key] = tensor
# There are some transposed parameters w.r.t their PyTorch counterpart
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
jax_state[key] = tensor.T
# Self Attention output projection needs to be transposed
if "out.kernel" in key:
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
1, 2, 0
)
# Pooler needs to transpose its kernel
if "pooler.dense.kernel" in key:
jax_state[key] = tensor.T
# Handle LayerNorm conversion
if "LayerNorm" in key:
del jax_state[key]
# Replace LayerNorm by layer_norm
new_key = key.replace("LayerNorm", "layer_norm")
if "weight" in key:
new_key = new_key.replace("weight", "gamma")
elif "bias" in key:
new_key = new_key.replace("bias", "beta")
jax_state[new_key] = tensor
return jax_state
def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):
model = FlaxBertModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
type_vocab_size=config.type_vocab_size,
max_length=config.max_position_embeddings,
num_encoder_layers=config.num_hidden_layers,
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
intermediate_size=config.intermediate_size,
)
super().__init__(config, model, state, seed)
@property
def module(self) -> nn.Module:
return self._module
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return self.model.apply(
{"params": self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,558 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Bert."""
import collections
import os
import unicodedata
from typing import List, Optional, Tuple
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
"bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"bert-base-uncased": 512,
"bert-large-uncased": 512,
"bert-base-cased": 512,
"bert-large-cased": 512,
"bert-base-multilingual-uncased": 512,
"bert-base-multilingual-cased": 512,
"bert-base-chinese": 512,
"bert-base-german-cased": 512,
"bert-large-uncased-whole-word-masking": 512,
"bert-large-cased-whole-word-masking": 512,
"bert-large-uncased-whole-word-masking-finetuned-squad": 512,
"bert-large-cased-whole-word-masking-finetuned-squad": 512,
"bert-base-cased-finetuned-mrpc": 512,
"bert-base-german-dbmdz-cased": 512,
"bert-base-german-dbmdz-uncased": 512,
"TurkuNLP/bert-base-finnish-cased-v1": 512,
"TurkuNLP/bert-base-finnish-uncased-v1": 512,
"wietsedv/bert-base-dutch-cased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"bert-base-uncased": {"do_lower_case": True},
"bert-large-uncased": {"do_lower_case": True},
"bert-base-cased": {"do_lower_case": False},
"bert-large-cased": {"do_lower_case": False},
"bert-base-multilingual-uncased": {"do_lower_case": True},
"bert-base-multilingual-cased": {"do_lower_case": False},
"bert-base-chinese": {"do_lower_case": False},
"bert-base-german-cased": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking": {"do_lower_case": True},
"bert-large-cased-whole-word-masking": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
"bert-base-german-dbmdz-cased": {"do_lower_case": False},
"bert-base-german-dbmdz-uncased": {"do_lower_case": True},
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n")
vocab[token] = index
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class BertTokenizer(PreTrainedTokenizer):
r"""
Construct a BERT tokenizer. Based on WordPiece.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.
Args:
vocab_file (:obj:`str`):
File containing the vocabulary.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to lowercase the input when tokenizing.
do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to do basic tokenization before WordPiece.
never_split (:obj:`Iterable`, `optional`):
Collection of tokens which will never be split during tokenization. Only has an effect when
:obj:`do_basic_tokenize=True`
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
The token used for padding, for example when batching sequences of different lengths.
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to tokenize Chinese characters.
This should likely be deactivated for Japanese (see this `issue
<https://github.com/huggingface/transformers/issues/328>`__).
strip_accents: (:obj:`bool`, `optional`):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for :obj:`lowercase` (as in the original BERT).
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
@property
def do_lower_case(self):
return self.basic_tokenizer.do_lower_case
@property
def vocab_size(self):
return len(self.vocab)
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
def _tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
# If the token is part of the never_split set
if token in self.basic_tokenizer.never_split:
split_tokens.append(token)
else:
split_tokens += self.wordpiece_tokenizer.tokenize(token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: ``[CLS] X [SEP]``
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
pair mask has the following format:
::
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
"Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file)
)
index = token_index
writer.write(token + "\n")
index += 1
return (vocab_file,)
class BasicTokenizer(object):
"""
Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
Args:
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to lowercase the input when tokenizing.
never_split (:obj:`Iterable`, `optional`):
Collection of tokens which will never be split during tokenization. Only has an effect when
:obj:`do_basic_tokenize=True`
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to tokenize Chinese characters.
This should likely be deactivated for Japanese (see this `issue
<https://github.com/huggingface/transformers/issues/328>`__).
strip_accents: (:obj:`bool`, `optional`):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for :obj:`lowercase` (as in the original BERT).
"""
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = set(never_split)
self.tokenize_chinese_chars = tokenize_chinese_chars
self.strip_accents = strip_accents
def tokenize(self, text, never_split=None):
"""
Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
WordPieceTokenizer.
Args:
**never_split**: (`optional`) list of str
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
:func:`PreTrainedTokenizer.tokenize`) List of token not to split.
"""
# union() returns a new set by concatenating the two sets.
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
if self.tokenize_chinese_chars:
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if token not in never_split:
if self.do_lower_case:
token = token.lower()
if self.strip_accents is not False:
token = self._run_strip_accents(token)
elif self.strip_accents:
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if never_split is not None and text in never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
tokenization using the given vocabulary.
For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens

View File

@@ -0,0 +1,259 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Tokenization classes for Bert."""
import json
from typing import List, Optional, Tuple
from tokenizers import normalizers
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .tokenization_bert import BertTokenizer
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
"bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
},
"tokenizer_file": {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/tokenizer.json",
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/tokenizer.json",
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/tokenizer.json",
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/tokenizer.json",
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/tokenizer.json",
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/tokenizer.json",
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/tokenizer.json",
"bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-tokenizer.json",
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/tokenizer.json",
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/tokenizer.json",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/tokenizer.json",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/tokenizer.json",
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/tokenizer.json",
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/tokenizer.json",
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/tokenizer.json",
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/tokenizer.json",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/tokenizer.json",
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/tokenizer.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"bert-base-uncased": 512,
"bert-large-uncased": 512,
"bert-base-cased": 512,
"bert-large-cased": 512,
"bert-base-multilingual-uncased": 512,
"bert-base-multilingual-cased": 512,
"bert-base-chinese": 512,
"bert-base-german-cased": 512,
"bert-large-uncased-whole-word-masking": 512,
"bert-large-cased-whole-word-masking": 512,
"bert-large-uncased-whole-word-masking-finetuned-squad": 512,
"bert-large-cased-whole-word-masking-finetuned-squad": 512,
"bert-base-cased-finetuned-mrpc": 512,
"bert-base-german-dbmdz-cased": 512,
"bert-base-german-dbmdz-uncased": 512,
"TurkuNLP/bert-base-finnish-cased-v1": 512,
"TurkuNLP/bert-base-finnish-uncased-v1": 512,
"wietsedv/bert-base-dutch-cased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"bert-base-uncased": {"do_lower_case": True},
"bert-large-uncased": {"do_lower_case": True},
"bert-base-cased": {"do_lower_case": False},
"bert-large-cased": {"do_lower_case": False},
"bert-base-multilingual-uncased": {"do_lower_case": True},
"bert-base-multilingual-cased": {"do_lower_case": False},
"bert-base-chinese": {"do_lower_case": False},
"bert-base-german-cased": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking": {"do_lower_case": True},
"bert-large-cased-whole-word-masking": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
"bert-base-german-dbmdz-cased": {"do_lower_case": False},
"bert-base-german-dbmdz-uncased": {"do_lower_case": True},
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}
class BertTokenizerFast(PreTrainedTokenizerFast):
r"""
Construct a "fast" BERT tokenizer (backed by HuggingFace's `tokenizers` library). Based on WordPiece.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
methods. Users should refer to this superclass for more information regarding those methods.
Args:
vocab_file (:obj:`str`):
File containing the vocabulary.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to lowercase the input when tokenizing.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
The token used for padding, for example when batching sequences of different lengths.
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to clean the text before tokenization by removing any control characters and replacing all
whitespaces by the classic one.
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see `this
issue <https://github.com/huggingface/transformers/issues/328>`__).
strip_accents: (:obj:`bool`, `optional`):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for :obj:`lowercase` (as in the original BERT).
wordpieces_prefix: (:obj:`str`, `optional`, defaults to :obj:`"##"`):
The prefix for subwords.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
slow_tokenizer_class = BertTokenizer
def __init__(
self,
vocab_file,
tokenizer_file=None,
do_lower_case=True,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs
):
super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
if (
pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
):
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
pre_tok_state["do_lower_case"] = do_lower_case
pre_tok_state["strip_accents"] = strip_accents
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
self.do_lower_case = do_lower_case
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: ``[CLS] X [SEP]``
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
if token_ids_1:
output += token_ids_1 + [self.sep_token_id]
return output
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
pair mask has the following format:
::
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)