get_activation('relu') provides a simple mapping from strings i… (#2807)
* activations.py contains a mapping from string to activation function * resolves some `gelu` vs `gelu_new` ambiguity
This commit is contained in:
@@ -18,12 +18,14 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import typing
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .activations import get_activation
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
@@ -1378,15 +1380,15 @@ class SequenceSummary(nn.Module):
|
||||
- 'attn' => Not implemented now, use multi-head attention
|
||||
summary_use_proj: Add a projection after the vector extraction
|
||||
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
|
||||
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
|
||||
summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
|
||||
summary_first_dropout: Add a dropout before the projection and activation
|
||||
summary_last_dropout: Add a dropout after the projection and activation
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
|
||||
self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
|
||||
self.summary_type = getattr(config, "summary_type", "last")
|
||||
if self.summary_type == "attn":
|
||||
# We should use a standard multi-head attention module with absolute positional embedding for that.
|
||||
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
|
||||
@@ -1401,9 +1403,10 @@ class SequenceSummary(nn.Module):
|
||||
num_classes = config.hidden_size
|
||||
self.summary = nn.Linear(config.hidden_size, num_classes)
|
||||
|
||||
self.activation = Identity()
|
||||
if hasattr(config, "summary_activation") and config.summary_activation == "tanh":
|
||||
self.activation = nn.Tanh()
|
||||
activation_string = getattr(config, "summary_activation", None)
|
||||
self.activation = (
|
||||
get_activation(activation_string) if activation_string else Identity()
|
||||
) # type: typing.Callable
|
||||
|
||||
self.first_dropout = Identity()
|
||||
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
|
||||
|
||||
Reference in New Issue
Block a user