do some (light) housekeeping
Several packages were imported but never used, indentation and line spaces did not follow PEP8.
This commit is contained in:
@@ -17,12 +17,10 @@
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -50,6 +48,7 @@ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model.
|
||||
"""
|
||||
@@ -125,12 +124,14 @@ def gelu(x):
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def gelu_new(x):
|
||||
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
@@ -140,6 +141,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_
|
||||
|
||||
BertLayerNorm = torch.nn.LayerNorm
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
@@ -482,7 +484,7 @@ BERT_START_DOCSTRING = r""" The BERT model was proposed in
|
||||
https://pytorch.org/docs/stable/nn.html#module
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
@@ -496,13 +498,13 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
(a) For sequence pairs:
|
||||
|
||||
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
|
||||
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
|
||||
|
||||
(b) For single sequences:
|
||||
|
||||
``tokens: [CLS] the dog is hairy . [SEP]``
|
||||
|
||||
|
||||
``token_type_ids: 0 0 0 0 0 0 0``
|
||||
|
||||
Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||
@@ -601,7 +603,7 @@ class BertModel(BertPreTrainedModel):
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
@@ -615,7 +617,7 @@ class BertModel(BertPreTrainedModel):
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
@@ -631,8 +633,9 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -692,7 +695,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
@@ -711,7 +714,8 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -764,7 +768,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -780,7 +784,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
r"""
|
||||
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -825,7 +830,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -842,8 +847,9 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForSequenceClassification(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -891,7 +897,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -915,8 +921,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForMultipleChoice(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -990,8 +997,9 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForTokenClassification(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -1037,7 +1045,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1062,8 +1070,9 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -1116,7 +1125,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user