do some (light) housekeeping
Several packages were imported but never used, indentation and line spaces did not follow PEP8.
This commit is contained in:
@@ -17,12 +17,10 @@
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -50,6 +48,7 @@ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model.
|
||||
"""
|
||||
@@ -125,12 +124,14 @@ def gelu(x):
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def gelu_new(x):
|
||||
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
@@ -140,6 +141,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_
|
||||
|
||||
BertLayerNorm = torch.nn.LayerNorm
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
@@ -632,7 +634,8 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -711,7 +714,8 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -780,7 +784,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
r"""
|
||||
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -843,7 +848,8 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForSequenceClassification(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -916,7 +922,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForMultipleChoice(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -991,7 +998,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForTokenClassification(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -1063,7 +1071,8 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
|
||||
Reference in New Issue
Block a user