From d64db6dfb94b302da876c03c989acf34deaa4ed4 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Tue, 13 Nov 2018 16:41:01 +0100 Subject: [PATCH] clean up pr --- modeling.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modeling.py b/modeling.py index 3b3f198c92..53243e5eb4 100644 --- a/modeling.py +++ b/modeling.py @@ -25,10 +25,7 @@ import six import torch import torch.nn as nn from torch.nn import CrossEntropyLoss - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish} - +from six import string_types def gelu(x): """Implementation of the gelu activation function. @@ -42,6 +39,9 @@ def swish(x): return x * torch.sigmoid(x) +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + class BertConfig(object): """Configuration class to store the configuration of a `BertModel`. """ @@ -68,7 +68,7 @@ class BertConfig(object): intermediate_size: The size of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. hidden_act: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" supported. + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. hidden_dropout_prob: The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. attention_probs_dropout_prob: The dropout ratio for the attention @@ -246,7 +246,7 @@ class BERTIntermediate(nn.Module): super(BERTIntermediate, self).__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.intermediate_act_fn = ACT2FN[config.hidden_act] \ - if isinstance(config.hidden_act, str) else config.hidden_act + if isinstance(config.hidden_act, string_types) else config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states)