613 lines
28 KiB
Python
613 lines
28 KiB
Python
import logging
|
|
import os
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from transformers import ElectraConfig, add_start_docstrings
|
|
from transformers.activations import get_activation
|
|
|
|
from .file_utils import add_start_docstrings_to_callable
|
|
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Maps each ELECTRA checkpoint identifier to the URL of its ``pytorch_model.bin`` file.
# Every URL follows the same layout, so the table is generated from a single template.
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    model_id: "https://s3.amazonaws.com/models.huggingface.co/bert/{}/pytorch_model.bin".format(model_id)
    for model_id in (
        "google/electra-small-generator",
        "google/electra-base-generator",
        "google/electra-large-generator",
        "google/electra-small-discriminator",
        "google/electra-base-discriminator",
        "google/electra-large-discriminator",
    )
}
|
|
|
|
|
|
def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
    """Load the weights of a TensorFlow checkpoint into a PyTorch ELECTRA model.

    Every checkpoint variable name is rewritten to follow the PyTorch attribute layout and is then
    resolved attribute by attribute, starting from ``model``. Variables whose path cannot be resolved
    on the model (``AttributeError``) are logged and skipped instead of aborting the conversion.

    Args:
        model: PyTorch model whose parameters are overwritten in place.
        config: Model configuration. Not used by this function.
        tf_checkpoint_path: Path to the TensorFlow checkpoint.
        discriminator_or_generator: When equal to ``"generator"``, the checkpoint's ``generator/``
            scope is loaded in place of its ``electra/`` scope. Defaults to ``"discriminator"``.

    Returns:
        The ``model`` that was passed in, with the matching weights replaced.

    Raises:
        ImportError: If ``re``, ``numpy`` or ``tensorflow`` cannot be imported.
        AssertionError: If a resolved parameter and its checkpoint array have different shapes. The
            exception args are ``(variable_name, pytorch_shape, checkpoint_shape)``.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        original_name = name

        try:
            # The masked-LM model reads its embeddings from the generator scope of the checkpoint.
            if isinstance(model, ElectraForMaskedLM):
                name = name.replace("electra/embeddings/", "generator/embeddings/")

            # Swap scopes so that the generator variables resolve under the ``electra`` attribute.
            # The order matters: ``electra/`` must be moved away before ``generator/`` takes its place.
            if discriminator_or_generator == "generator":
                name = name.replace("electra/", "discriminator/")
                name = name.replace("generator/", "electra/")

            name = name.replace("dense_1", "dense_prediction")
            name = name.replace("generator_predictions/output_bias", "generator_lm_head/bias")

            name = name.split("/")

            # Variables named ``global_step`` or ``temperature`` are deliberately not loaded.
            if any(n in ["global_step", "temperature"] for n in name):
                logger.info("Skipping {}".format(original_name))
                continue

            pointer = model
            for m_name in name:
                # Names such as ``layer_3`` address item 3 of the ``layer`` attribute.
                if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                    scope_names = re.split(r"_(\d+)", m_name)
                else:
                    scope_names = [m_name]

                if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                    pointer = getattr(pointer, "weight")
                elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                    pointer = getattr(pointer, "bias")
                elif scope_names[0] == "output_weights":
                    pointer = getattr(pointer, "weight")
                elif scope_names[0] == "squad":
                    pointer = getattr(pointer, "classifier")
                else:
                    pointer = getattr(pointer, scope_names[0])

                if len(scope_names) >= 2:
                    num = int(scope_names[1])
                    pointer = pointer[num]

            # ``m_name`` is the last path component at this point.
            if m_name.endswith("_embeddings"):
                pointer = getattr(pointer, "weight")
            elif m_name == "kernel":
                # Kernels are transposed to match the shape of the PyTorch ``weight`` parameter.
                array = np.transpose(array)

            # Raised explicitly (rather than with ``assert``) so the check survives ``python -O``.
            if pointer.shape != array.shape:
                raise AssertionError(original_name, pointer.shape, array.shape)

            logger.info("Initialize PyTorch weight {} {}".format(name, original_name))
            pointer.data = torch.from_numpy(array)
        except AttributeError as e:
            # The variable has no counterpart on this model: report it and move on.
            logger.info("Skipping {} {} {}".format(original_name, name, e))
            continue

    return model
|
|
|
|
|
|
class ElectraEmbeddings(BertEmbeddings):
    """Word, position and token-type embeddings, all sized by ``config.embedding_size``."""

    def __init__(self, config):
        super().__init__(config)

        # The tables created by the parent constructor are replaced so that every one of them
        # uses ``config.embedding_size`` as its embedding dimension.
        embedding_dim = config.embedding_size
        self.word_embeddings = nn.Embedding(config.vocab_size, embedding_dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, embedding_dim)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, embedding_dim)

        # The attribute keeps the TensorFlow variable name (``LayerNorm``, not snake case) so that
        # TensorFlow checkpoint files can be loaded by attribute lookup.
        self.LayerNorm = BertLayerNorm(embedding_dim, eps=config.layer_norm_eps)
|
|
|
|
|
|
class ElectraDiscriminatorPredictions(nn.Module):
    """Prediction module for the discriminator, made up of two dense layers.

    The first layer keeps the hidden size and is followed by the activation named by
    ``config.hidden_act``; the second layer maps every position to a single logit.
    """

    def __init__(self, config):
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_prediction = nn.Linear(config.hidden_size, 1)
        self.config = config

    def forward(self, discriminator_hidden_states, attention_mask):
        """Return one logit per position.

        Args:
            discriminator_hidden_states: Hidden states whose last dimension is ``config.hidden_size``.
            attention_mask: Accepted for interface compatibility; not used by this module.

        Returns:
            The logits, with the trailing singleton dimension removed.
        """
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = get_activation(self.config.hidden_act)(hidden_states)
        # Only the trailing size-1 dimension is dropped. A bare ``squeeze()`` would also drop the
        # batch (or sequence) dimension whenever that dimension happens to have size 1.
        logits = self.dense_prediction(hidden_states).squeeze(-1)

        return logits
|
|
|
|
|
|
class ElectraGeneratorPredictions(nn.Module):
    """Prediction module for the generator: dense projection, GELU, then layer normalization."""

    def __init__(self, config):
        super().__init__()

        self.LayerNorm = BertLayerNorm(config.embedding_size)
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)

    def forward(self, generator_hidden_states):
        """Project the hidden states to ``config.embedding_size``, activate and normalize them."""
        projected = self.dense(generator_hidden_states)
        activated = get_activation("gelu")(projected)
        return self.LayerNorm(activated)
|
|
|
|
|
|
class ElectraPreTrainedModel(BertPreTrainedModel):
    """Abstract base class of the ELECTRA models.

    It only binds the ELECTRA-specific configuration class, checkpoint archive map, TensorFlow
    weight loader and base-model prefix; everything else is inherited from ``BertPreTrainedModel``.
    """

    # Attribute name under which the head models in this file store the bare ``ElectraModel``.
    base_model_prefix = "electra"
    config_class = ElectraConfig
    # Function used to load TensorFlow checkpoints into these models.
    load_tf_weights = load_tf_weights_in_electra
    # Checkpoint identifier -> URL of the PyTorch weights file.
    pretrained_model_archive_map = ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
|
|
|
|
# Shared introduction appended to the class docstrings of the ELECTRA models via ``add_start_docstrings``.
ELECTRA_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.ElectraConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
|
|
|
|
# Shared description of the ``forward`` arguments, attached with ``add_start_docstrings_to_callable``.
# Only arguments that the ``forward`` methods in this file actually accept are documented here.
ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.ElectraTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""
|
|
|
|
|
|
# NOTE: adjacent string literals are concatenated verbatim, so each sentence below must carry its
# own trailing space; otherwise the rendered description reads "...are different.Both the...".
@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
    "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
    "hidden size and embedding size are different. "
    "Both the generator and discriminator checkpoints may be loaded into this model.",
    ELECTRA_START_DOCSTRING,
)
class ElectraModel(ElectraPreTrainedModel):

    config_class = ElectraConfig

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = ElectraEmbeddings(config)

        # The projection exists only when the embedding width differs from the encoder width;
        # ``forward`` detects it with ``hasattr``.
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)

        self.encoder = BertEncoder(config)
        self.config = config
        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import ElectraModel, ElectraTokenizer
        import torch

        tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
        model = ElectraModel.from_pretrained('google/electra-small-discriminator')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)

        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
        # Exactly one of ``input_ids`` / ``inputs_embeds`` must be given; it determines the input shape.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            # Drop the trailing embedding dimension.
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Defaults: attend to every position and use token type 0 everywhere.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        hidden_states = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # Present only when ``config.embedding_size != config.hidden_size`` (see ``__init__``).
        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states)

        hidden_states = self.encoder(hidden_states, attention_mask=extended_attention_mask, head_mask=head_mask)

        return hidden_states
|
|
|
|
|
|
@add_start_docstrings(
    """
    Electra model with a binary classification head on top as used during pre-training for identifying generated
    tokens.

    It is recommended to load the discriminator checkpoint into that model.""",
    ELECTRA_START_DOCSTRING,
)
class ElectraForPreTraining(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
        self.init_weights()

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see :obj:`input_ids` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates the token is an original token,
            ``1`` indicates the token was replaced.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total loss of the ELECTRA objective.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`)
            Prediction scores of the head (scores for each token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.


    Examples::

        from transformers import ElectraTokenizer, ElectraForPreTraining
        import torch

        tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
        model = ElectraForPreTraining.from_pretrained('google/electra-small-discriminator')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)

        scores = outputs[0]  # Without labels, the per-token scores are the first element of the output tuple

        """

        discriminator_hidden_states = self.electra(
            input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        logits = self.discriminator_predictions(discriminator_sequence_output, attention_mask)

        output = (logits,)

        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            sequence_length = discriminator_sequence_output.shape[1]
            if attention_mask is not None:
                # Only positions whose attention mask equals 1 contribute to the loss.
                active_loss = attention_mask.view(-1, sequence_length) == 1
                active_logits = logits.view(-1, sequence_length)[active_loss]
                active_labels = labels[active_loss]
                loss = loss_fct(active_logits, active_labels.float())
            else:
                loss = loss_fct(logits.view(-1, sequence_length), labels.float())

            output = (loss,) + output

        # Forward whatever else the base model returned after the sequence output.
        output += discriminator_hidden_states[1:]

        return output  # (loss), scores, (hidden_states), (attentions)
|
|
|
|
|
|
@add_start_docstrings(
    """
    Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is
    the only model of the two to have been trained for the masked language modeling task.""",
    ELECTRA_START_DOCSTRING,
)
class ElectraForMaskedLM(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)

        # Maps the embedding-sized output of the prediction module to vocabulary scores.
        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
        self.init_weights()

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.generator_lm_head

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
    ):
        r"""
        masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        masked_lm_loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import ElectraTokenizer, ElectraForMaskedLM
        import torch

        tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
        model = ElectraForMaskedLM.from_pretrained('google/electra-small-generator')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, masked_lm_labels=input_ids)

        loss, prediction_scores = outputs[:2]

        """

        generator_hidden_states = self.electra(
            input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds
        )
        sequence_output = generator_hidden_states[0]

        # Prediction module first, then the vocabulary projection.
        prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))

        # Everything the base model returned after the sequence output is passed through unchanged.
        extras = generator_hidden_states[1:]

        if masked_lm_labels is None:
            return (prediction_scores,) + extras  # prediction_scores, (hidden_states), (attentions)

        # Masked language modeling softmax layer
        loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
        flat_scores = prediction_scores.view(-1, self.config.vocab_size)
        loss = loss_fct(flat_scores, masked_lm_labels.view(-1))

        return (loss, prediction_scores) + extras  # masked_lm_loss, prediction_scores, (hidden_states), (attentions)
|
|
|
|
|
|
@add_start_docstrings(
    """
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.""",
    ELECTRA_START_DOCSTRING,
)
class ElectraForTokenClassification(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
            Classification loss.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import ElectraTokenizer, ElectraForTokenClassification
        import torch

        tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
        model = ElectraForTokenClassification.from_pretrained('google/electra-small-discriminator')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)

        loss, scores = outputs[:2]

        """

        discriminator_hidden_states = self.electra(
            input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds
        )
        sequence_output = self.dropout(discriminator_hidden_states[0])
        logits = self.classifier(sequence_output)

        # Everything the base model returned after the sequence output is passed through unchanged.
        extras = discriminator_hidden_states[1:]

        if labels is None:
            return (logits,) + extras  # scores, (hidden_states), (attentions)

        loss_fct = nn.CrossEntropyLoss()
        if attention_mask is None:
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        else:
            # Only keep active parts of the loss
            keep = attention_mask.view(-1) == 1
            kept_logits = logits.view(-1, self.config.num_labels)[keep]
            kept_labels = labels.view(-1)[keep]
            loss = loss_fct(kept_logits, kept_labels)

        return (loss, logits) + extras  # loss, scores, (hidden_states), (attentions)
|