[style] consistent nn. and nn.functional: part 4 examples (#12156)
* consistent nn. and nn.functional: p4 examples * restore
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ClassificationHead(torch.nn.Module):
|
||||
class ClassificationHead(nn.Module):
|
||||
"""Classification Head for transformer encoders"""
|
||||
|
||||
def __init__(self, class_size, embed_size):
|
||||
super().__init__()
|
||||
self.class_size = class_size
|
||||
self.embed_size = embed_size
|
||||
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
|
||||
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
|
||||
self.mlp = torch.nn.Linear(embed_size, class_size)
|
||||
# self.mlp1 = nn.Linear(embed_size, embed_size)
|
||||
# self.mlp2 = (nn.Linear(embed_size, class_size))
|
||||
self.mlp = nn.Linear(embed_size, class_size)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
# hidden_state = F.relu(self.mlp1(hidden_state))
|
||||
# hidden_state = nn.functional.relu(self.mlp1(hidden_state))
|
||||
# hidden_state = self.mlp2(hidden_state)
|
||||
logits = self.mlp(hidden_state)
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user