clean up pr
This commit is contained in:
12
modeling.py
12
modeling.py
@@ -25,10 +25,7 @@ import six
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from six import string_types
|
||||||
|
|
||||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
|
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
"""Implementation of the gelu activation function.
|
"""Implementation of the gelu activation function.
|
||||||
@@ -42,6 +39,9 @@ def swish(x):
|
|||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||||
|
|
||||||
|
|
||||||
class BertConfig(object):
|
class BertConfig(object):
|
||||||
"""Configuration class to store the configuration of a `BertModel`.
|
"""Configuration class to store the configuration of a `BertModel`.
|
||||||
"""
|
"""
|
||||||
@@ -68,7 +68,7 @@ class BertConfig(object):
|
|||||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||||
layer in the Transformer encoder.
|
layer in the Transformer encoder.
|
||||||
hidden_act: The non-linear activation function (function or string) in the
|
hidden_act: The non-linear activation function (function or string) in the
|
||||||
encoder and pooler. If string, "gelu", "relu" and "swish" supported.
|
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
|
||||||
hidden_dropout_prob: The dropout probabilitiy for all fully connected
|
hidden_dropout_prob: The dropout probabilitiy for all fully connected
|
||||||
layers in the embeddings, encoder, and pooler.
|
layers in the embeddings, encoder, and pooler.
|
||||||
attention_probs_dropout_prob: The dropout ratio for the attention
|
attention_probs_dropout_prob: The dropout ratio for the attention
|
||||||
@@ -246,7 +246,7 @@ class BERTIntermediate(nn.Module):
|
|||||||
super(BERTIntermediate, self).__init__()
|
super(BERTIntermediate, self).__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
||||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
if isinstance(config.hidden_act, string_types) else config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
|
|||||||
Reference in New Issue
Block a user