get_activation('relu') provides a simple mapping from strings i… (#2807)
* activations.py contains a mapping from string to activation function * resolves some `gelu` vs `gelu_new` ambiguity
This commit is contained in:
@@ -25,6 +25,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .activations import gelu_new, swish
|
||||
from .configuration_openai import OpenAIGPTConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
|
||||
@@ -114,15 +115,7 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
|
||||
return model
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
|
||||
ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu_new}
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user