Black 20 release
This commit is contained in:
@@ -64,8 +64,7 @@ MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]
|
||||
|
||||
|
||||
def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model.
|
||||
"""
|
||||
"""Load tf checkpoints in a pytorch model."""
|
||||
try:
|
||||
import re
|
||||
|
||||
@@ -161,8 +160,7 @@ NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm}
|
||||
|
||||
|
||||
class MobileBertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -663,8 +661,8 @@ class MobileBertPreTrainingHeads(nn.Module):
|
||||
|
||||
|
||||
class MobileBertPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = MobileBertConfig
|
||||
@@ -788,7 +786,7 @@ MOBILEBERT_INPUTS_DOCSTRING = r"""
|
||||
)
|
||||
class MobileBertModel(MobileBertPreTrainedModel):
|
||||
"""
|
||||
https://arxiv.org/pdf/2004.02984.pdf
|
||||
https://arxiv.org/pdf/2004.02984.pdf
|
||||
"""
|
||||
|
||||
authorized_missing_keys = [r"position_ids"]
|
||||
@@ -809,9 +807,9 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
@@ -965,31 +963,31 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the masked language modeling loss.
|
||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
|
||||
Indices should be in ``[0, 1]``.
|
||||
``0`` indicates sequence B is a continuation of sequence A,
|
||||
``1`` indicates sequence B is a random sequence.
|
||||
Returns:
|
||||
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the masked language modeling loss.
|
||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
|
||||
Indices should be in ``[0, 1]``.
|
||||
``0`` indicates sequence B is a continuation of sequence A,
|
||||
``1`` indicates sequence B is a random sequence.
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MobileBertTokenizer, MobileBertForPreTraining
|
||||
>>> import torch
|
||||
>>> from transformers import MobileBertTokenizer, MobileBertForPreTraining
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
|
||||
>>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased", return_dict=True)
|
||||
>>> tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
|
||||
>>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased", return_dict=True)
|
||||
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
|
||||
>>> prediction_logits = outptus.prediction_logits
|
||||
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
||||
>>> prediction_logits = outptus.prediction_logits
|
||||
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
||||
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
@@ -1176,29 +1174,29 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
|
||||
Indices should be in ``[0, 1]``.
|
||||
``0`` indicates sequence B is a continuation of sequence A,
|
||||
``1`` indicates sequence B is a random sequence.
|
||||
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
|
||||
Indices should be in ``[0, 1]``.
|
||||
``0`` indicates sequence B is a continuation of sequence A,
|
||||
``1`` indicates sequence B is a random sequence.
|
||||
|
||||
Returns:
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MobileBertTokenizer, MobileBertForNextSentencePrediction
|
||||
>>> import torch
|
||||
>>> from transformers import MobileBertTokenizer, MobileBertForNextSentencePrediction
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
|
||||
>>> model = MobileBertForNextSentencePrediction.from_pretrained('google/mobilebert-uncased', return_dict=True)
|
||||
>>> tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
|
||||
>>> model = MobileBertForNextSentencePrediction.from_pretrained('google/mobilebert-uncased', return_dict=True)
|
||||
|
||||
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
|
||||
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
|
||||
|
||||
>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
|
||||
>>> loss = outputs.loss
|
||||
>>> logits = outputs.logits
|
||||
>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
|
||||
>>> loss = outputs.loss
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@@ -1308,7 +1306,10 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -1491,7 +1492,10 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return MultipleChoiceModelOutput(
|
||||
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
|
||||
loss=loss,
|
||||
logits=reshaped_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -1574,5 +1578,8 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user