Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There's a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
This commit is contained in:
@@ -27,8 +27,8 @@ from .configuration_utils import PretrainedConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json",
|
||||
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json",
|
||||
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -63,24 +63,26 @@ class XxxConfig(PretrainedConfig):
|
||||
"""
|
||||
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=50257,
|
||||
n_positions=1024,
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
summary_type='cls_index',
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50257,
|
||||
n_positions=1024,
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
summary_type="cls_index",
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
**kwargs
|
||||
):
|
||||
super(XxxConfig, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.n_ctx = n_ctx
|
||||
|
||||
@@ -24,8 +24,10 @@ import torch
|
||||
from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = XxxConfig.from_json_file(config_file)
|
||||
@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.config_file,
|
||||
args.pytorch_dump_path)
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
||||
|
||||
@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
||||
# for the pretrained weights provided with the models
|
||||
####################################################
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
|
||||
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
|
||||
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
|
||||
}
|
||||
|
||||
####################################################
|
||||
@@ -69,9 +69,9 @@ TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
class TFXxxLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super(TFXxxLayer, self).__init__(**kwargs)
|
||||
self.attention = TFXxxAttention(config, name='attention')
|
||||
self.intermediate = TFXxxIntermediate(config, name='intermediate')
|
||||
self.transformer_output = TFXxxOutput(config, name='output')
|
||||
self.attention = TFXxxAttention(config, name="attention")
|
||||
self.intermediate = TFXxxIntermediate(config, name="intermediate")
|
||||
self.transformer_output = TFXxxOutput(config, name="output")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask = inputs
|
||||
@@ -98,7 +98,9 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(
|
||||
self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False
|
||||
):
|
||||
# We allow three types of multi-inputs:
|
||||
# - traditional keyword arguments in the call method
|
||||
# - all the arguments provided as a dict in the first positional argument of call
|
||||
@@ -113,11 +115,11 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
@@ -175,6 +177,7 @@ class TFXxxPreTrainedModel(TFPreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = XxxConfig
|
||||
pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "transformer"
|
||||
@@ -263,8 +266,12 @@ XXX_INPUTS_DOCSTRING = r"""
|
||||
than the model's internal embedding lookup matrix.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class TFXxxModel(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -297,17 +304,19 @@ class TFXxxModel(TFXxxPreTrainedModel):
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
|
||||
)
|
||||
class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -333,26 +342,30 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
||||
prediction_scores = outputs[0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name='mlm')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name="mlm")
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
|
||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
|
||||
|
||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||
|
||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -378,22 +391,23 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
||||
logits = outputs[0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
|
||||
pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
@@ -401,9 +415,12 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
||||
return outputs # logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -429,22 +446,23 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
||||
scores = outputs[0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
|
||||
sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
@@ -452,9 +470,12 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
||||
return outputs # scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -482,14 +503,15 @@ class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
||||
start_scores, end_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='qa_outputs')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
||||
# for the pretrained weights provided with the models
|
||||
####################################################
|
||||
XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin",
|
||||
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin",
|
||||
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin",
|
||||
}
|
||||
|
||||
####################################################
|
||||
@@ -60,8 +60,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
logger.error(
|
||||
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
@@ -76,7 +78,7 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
name = name.split("/")
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
@@ -84,18 +86,18 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||
l = re.split(r"_(\d+)", m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'squad':
|
||||
pointer = getattr(pointer, 'classifier')
|
||||
if l[0] == "kernel" or l[0] == "gamma":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "output_bias" or l[0] == "beta":
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif l[0] == "output_weights":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "squad":
|
||||
pointer = getattr(pointer, "classifier")
|
||||
else:
|
||||
try:
|
||||
pointer = getattr(pointer, l[0])
|
||||
@@ -105,9 +107,9 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
if m_name[-11:] == "_embeddings":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "kernel":
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
@@ -147,7 +149,6 @@ class XxxLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
####################################################
|
||||
# PreTrainedModel is a sub-class of torch.nn.Module
|
||||
# which take care of loading and saving pretrained weights
|
||||
@@ -161,6 +162,7 @@ class XxxPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = XxxConfig
|
||||
pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_xxx
|
||||
@@ -246,8 +248,12 @@ XXX_INPUTS_DOCSTRING = r"""
|
||||
than the model's internal embedding lookup matrix.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class XxxModel(XxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -277,6 +283,7 @@ class XxxModel(XxxPreTrainedModel):
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxModel, self).__init__(config)
|
||||
|
||||
@@ -300,7 +307,15 @@ class XxxModel(XxxPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@@ -329,7 +344,7 @@ class XxxModel(XxxPreTrainedModel):
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
@@ -342,14 +357,20 @@ class XxxModel(XxxPreTrainedModel):
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
head_mask = (
|
||||
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||
) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
##################################
|
||||
# Replace this with your model code
|
||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
|
||||
sequence_output = encoder_outputs[0]
|
||||
outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
@@ -357,8 +378,9 @@ class XxxModel(XxxPreTrainedModel):
|
||||
return outputs # sequence_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
|
||||
)
|
||||
class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -389,6 +411,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
loss, prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxForMaskedLM, self).__init__(config)
|
||||
|
||||
@@ -400,15 +423,25 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
masked_lm_labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
masked_lm_labels=None,
|
||||
):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
@@ -422,9 +455,12 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -456,6 +492,7 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -466,15 +503,25 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -496,9 +543,12 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -528,6 +578,7 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxForTokenClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -538,15 +589,25 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
@@ -569,9 +630,12 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
return outputs # (loss), scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class XxxForQuestionAnswering(XxxPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -613,6 +677,7 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxForQuestionAnswering, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -622,15 +687,26 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
start_positions=None, end_positions=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from __future__ import print_function
|
||||
import unittest
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
@@ -27,46 +27,57 @@ from transformers import XxxConfig, is_tf_available
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from transformers.modeling_tf_xxx import (TFXxxModel, TFXxxForMaskedLM,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification,
|
||||
TFXxxForQuestionAnswering,
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from transformers.modeling_tf_xxx import (
|
||||
TFXxxModel,
|
||||
TFXxxForMaskedLM,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification,
|
||||
TFXxxForQuestionAnswering,
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
all_model_classes = (TFXxxModel, TFXxxForMaskedLM, TFXxxForQuestionAnswering,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification) if is_tf_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
TFXxxModel,
|
||||
TFXxxForMaskedLM,
|
||||
TFXxxForQuestionAnswering,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
class TFXxxModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
@@ -120,15 +131,16 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFXxxModel(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
sequence_output, pooled_output = model(inputs)
|
||||
|
||||
inputs = [input_ids, input_mask]
|
||||
@@ -141,78 +153,74 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
"pooled_output": pooled_output.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].shape),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFXxxForMaskedLM(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
prediction_scores, = model(inputs)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(prediction_scores,) = model(inputs)
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFXxxForSequenceClassification(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
logits, = model(inputs)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.num_labels])
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFXxxForTokenClassification(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
logits, = model(inputs)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.seq_length, self.num_labels])
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFXxxForQuestionAnswering(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
start_logits, end_logits = model(inputs)
|
||||
result = {
|
||||
"start_logits": start_logits.numpy(),
|
||||
"end_logits": end_logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].shape),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].shape),
|
||||
[self.batch_size, self.seq_length])
|
||||
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, token_type_ids, input_mask,
|
||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
@@ -244,9 +252,10 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ['xxx-base-uncased']:
|
||||
for model_name in ["xxx-base-uncased"]:
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -20,51 +20,60 @@ import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .modeling_common_test import CommonTestCases, ids_tensor
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
|
||||
XxxForNextSentencePrediction, XxxForPreTraining,
|
||||
XxxForQuestionAnswering, XxxForSequenceClassification,
|
||||
XxxForTokenClassification, XxxForMultipleChoice)
|
||||
from transformers import (
|
||||
XxxConfig,
|
||||
XxxModel,
|
||||
XxxForMaskedLM,
|
||||
XxxForNextSentencePrediction,
|
||||
XxxForPreTraining,
|
||||
XxxForQuestionAnswering,
|
||||
XxxForSequenceClassification,
|
||||
XxxForTokenClassification,
|
||||
XxxForMultipleChoice,
|
||||
)
|
||||
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@require_torch
|
||||
class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (XxxModel, XxxForMaskedLM, XxxForQuestionAnswering,
|
||||
XxxForSequenceClassification,
|
||||
XxxForTokenClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
class XxxModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
@@ -118,16 +127,17 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = XxxModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -140,83 +150,98 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = XxxForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = XxxForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
loss, start_logits, end_logits = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = XxxForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.num_labels])
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = XxxForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.num_labels])
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, token_type_ids, input_mask,
|
||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
@@ -252,5 +277,6 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -18,10 +18,11 @@ import os
|
||||
import unittest
|
||||
from io import open
|
||||
|
||||
from transformers.tokenization_bert import (XxxTokenizer, VOCAB_FILES_NAMES)
|
||||
from transformers.tokenization_bert import XxxTokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from .tokenization_tests_commons import CommonTestCases
|
||||
|
||||
|
||||
class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
|
||||
tokenizer_class = XxxTokenizer
|
||||
@@ -30,28 +31,39 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
super(XxxTokenizationTest, self).setUp()
|
||||
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ",", "low", "lowest",
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
"wa",
|
||||
"un",
|
||||
"runn",
|
||||
"##ing",
|
||||
",",
|
||||
"low",
|
||||
"lowest",
|
||||
]
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"UNwant\u00E9d,running"
|
||||
output_text = u"unwanted, running"
|
||||
input_text = "UNwant\u00E9d,running"
|
||||
output_text = "unwanted, running"
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
|
||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -34,17 +34,16 @@ logger = logging.getLogger(__name__)
|
||||
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
||||
# to file names for serializing Tokenizer instances
|
||||
####################################################
|
||||
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
|
||||
####################################################
|
||||
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
||||
# to pretrained vocabulary URL for all the model shortcut names.
|
||||
####################################################
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
'vocab_file':
|
||||
{
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
|
||||
"vocab_file": {
|
||||
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
|
||||
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,8 +51,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
||||
# Mapping from model shortcut names to max length of inputs
|
||||
####################################################
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
'xxx-base-uncased': 512,
|
||||
'xxx-large-uncased': 512,
|
||||
"xxx-base-uncased": 512,
|
||||
"xxx-large-uncased": 512,
|
||||
}
|
||||
|
||||
####################################################
|
||||
@@ -62,8 +61,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
# To be used for checkpoint specific configurations.
|
||||
####################################################
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
'xxx-base-uncased': {'do_lower_case': True},
|
||||
'xxx-large-uncased': {'do_lower_case': True},
|
||||
"xxx-base-uncased": {"do_lower_case": True},
|
||||
"xxx-large-uncased": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
@@ -73,7 +72,7 @@ def load_vocab(vocab_file):
|
||||
with open(vocab_file, "r", encoding="utf-8") as reader:
|
||||
tokens = reader.readlines()
|
||||
for index, token in enumerate(tokens):
|
||||
token = token.rstrip('\n')
|
||||
token = token.rstrip("\n")
|
||||
vocab[token] = index
|
||||
return vocab
|
||||
|
||||
@@ -93,9 +92,17 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True,
|
||||
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
|
||||
mask_token="[MASK]", **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=True,
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]",
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs a XxxTokenizer.
|
||||
|
||||
Args:
|
||||
@@ -104,16 +111,22 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
Whether to lower case the input
|
||||
Only has an effect when do_basic_tokenize=True
|
||||
"""
|
||||
super(XxxTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
|
||||
pad_token=pad_token, cls_token=cls_token,
|
||||
mask_token=mask_token, **kwargs)
|
||||
super(XxxTokenizer, self).__init__(
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
cls_token=cls_token,
|
||||
mask_token=mask_token,
|
||||
**kwargs
|
||||
)
|
||||
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
|
||||
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
|
||||
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
|
||||
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
|
||||
)
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
|
||||
@property
|
||||
@@ -142,7 +155,7 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ' '.join(tokens).replace(' ##', '').strip()
|
||||
out_string = " ".join(tokens).replace(" ##", "").strip()
|
||||
return out_string
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
@@ -177,8 +190,10 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError("You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model.")
|
||||
raise ValueError(
|
||||
"You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model."
|
||||
)
|
||||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||
|
||||
if token_ids_1 is not None:
|
||||
@@ -204,15 +219,17 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
"""Save the tokenizer vocabulary to a directory or file."""
|
||||
index = 0
|
||||
if os.path.isdir(vocab_path):
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||
else:
|
||||
vocab_file = vocab_path
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
|
||||
" Please check that the vocabulary is not corrupted!".format(vocab_file))
|
||||
logger.warning(
|
||||
"Saving vocabulary to {}: vocabulary indices are not consecutive."
|
||||
" Please check that the vocabulary is not corrupted!".format(vocab_file)
|
||||
)
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
writer.write(token + "\n")
|
||||
index += 1
|
||||
return (vocab_file,)
|
||||
|
||||
Reference in New Issue
Block a user