💄 super: replace `super(Class, self).__init__()` with zero-argument `super().__init__()`
This commit is contained in:
@@ -127,7 +127,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, nx, n_ctx, config, scale=False):
|
||||
- super(Attention, self).__init__()
|
||||
+ super().__init__()
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
assert n_state % config.n_head == 0
|
||||
@@ -221,7 +221,7 @@ class Attention(nn.Module):
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
||||
- super(MLP, self).__init__()
|
||||
+ super().__init__()
|
||||
nx = config.n_embd
|
||||
self.c_fc = Conv1D(n_state, nx)
|
||||
self.c_proj = Conv1D(nx, n_state)
|
||||
@@ -236,7 +236,7 @@ class MLP(nn.Module):
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_ctx, config, scale=False):
|
||||
- super(Block, self).__init__()
|
||||
+ super().__init__()
|
||||
nx = config.n_embd
|
||||
self.attn = Attention(nx, n_ctx, config, scale)
|
||||
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
@@ -359,7 +359,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
- super(OpenAIGPTModel, self).__init__(config)
|
||||
+ super().__init__(config)
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
@@ -518,7 +518,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
- super(OpenAIGPTLMHeadModel, self).__init__(config)
|
||||
+ super().__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
|
||||
@@ -623,7 +623,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
- super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
|
||||
+ super().__init__(config)
|
||||
|
||||
config.num_labels = 1
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
|
||||
Reference in New Issue
Block a user