get_activation('relu') provides a simple mapping from strings i… (#2807)

* activations.py contains a mapping from string to activation function
* resolves some `gelu` vs `gelu_new` ambiguity
This commit is contained in:
Sam Shleifer
2020-02-13 08:28:33 -05:00
committed by GitHub
parent f54a5bd37f
commit ef74b0f07a
9 changed files with 94 additions and 68 deletions

View File

@@ -18,12 +18,14 @@
import logging
import os
import typing
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .file_utils import (
DUMMY_INPUTS,
@@ -1378,15 +1380,15 @@ class SequenceSummary(nn.Module):
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
summary_first_dropout: Add a dropout before the projection and activation
summary_last_dropout: Add a dropout after the projection and activation
"""
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
self.summary_type = getattr(config, "summary_type", "last")
if self.summary_type == "attn":
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
@@ -1401,9 +1403,10 @@ class SequenceSummary(nn.Module):
num_classes = config.hidden_size
self.summary = nn.Linear(config.hidden_size, num_classes)
self.activation = Identity()
if hasattr(config, "summary_activation") and config.summary_activation == "tanh":
self.activation = nn.Tanh()
activation_string = getattr(config, "summary_activation", None)
self.activation = (
get_activation(activation_string) if activation_string else Identity()
) # type: typing.Callable
self.first_dropout = Identity()
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: