output attentions in GPT-2
This commit is contained in:
@@ -223,7 +223,7 @@ class Conv1D(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, nx, n_ctx, config, scale=False):
|
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
|
||||||
super(Attention, self).__init__()
|
super(Attention, self).__init__()
|
||||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||||
@@ -232,6 +232,7 @@ class Attention(nn.Module):
|
|||||||
self.n_head = config.n_head
|
self.n_head = config.n_head
|
||||||
self.split_size = n_state
|
self.split_size = n_state
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
self.output_attentions = output_attentions
|
||||||
self.c_attn = Conv1D(n_state * 3, nx)
|
self.c_attn = Conv1D(n_state * 3, nx)
|
||||||
self.c_proj = Conv1D(n_state, nx)
|
self.c_proj = Conv1D(n_state, nx)
|
||||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||||
@@ -247,6 +248,8 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
w = nn.Softmax(dim=-1)(w)
|
w = nn.Softmax(dim=-1)(w)
|
||||||
w = self.attn_dropout(w)
|
w = self.attn_dropout(w)
|
||||||
|
if self.output_attentions:
|
||||||
|
return w, torch.matmul(w, v)
|
||||||
return torch.matmul(w, v)
|
return torch.matmul(w, v)
|
||||||
|
|
||||||
def merge_heads(self, x):
|
def merge_heads(self, x):
|
||||||
@@ -274,9 +277,13 @@ class Attention(nn.Module):
|
|||||||
value = torch.cat((past_value, value), dim=-2)
|
value = torch.cat((past_value, value), dim=-2)
|
||||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
||||||
a = self._attn(query, key, value)
|
a = self._attn(query, key, value)
|
||||||
|
if self.output_attentions:
|
||||||
|
attentions, a = a
|
||||||
a = self.merge_heads(a)
|
a = self.merge_heads(a)
|
||||||
a = self.c_proj(a)
|
a = self.c_proj(a)
|
||||||
a = self.resid_dropout(a)
|
a = self.resid_dropout(a)
|
||||||
|
if self.output_attentions:
|
||||||
|
return attentions, a, present
|
||||||
return a, present
|
return a, present
|
||||||
|
|
||||||
|
|
||||||
@@ -296,19 +303,26 @@ class MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, n_ctx, config, scale=False):
|
def __init__(self, n_ctx, config, scale=False, output_attentions=False):
|
||||||
super(Block, self).__init__()
|
super(Block, self).__init__()
|
||||||
nx = config.n_embd
|
nx = config.n_embd
|
||||||
|
self.output_attentions = output_attentions
|
||||||
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||||
self.attn = Attention(nx, n_ctx, config, scale)
|
self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
|
||||||
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = MLP(4 * nx, config)
|
self.mlp = MLP(4 * nx, config)
|
||||||
|
|
||||||
def forward(self, x, layer_past=None):
|
def forward(self, x, layer_past=None):
|
||||||
a, present = self.attn(self.ln_1(x), layer_past=layer_past)
|
output_attn = self.attn(self.ln_1(x), layer_past=layer_past)
|
||||||
|
if self.output_attentions:
|
||||||
|
attentions, a, present = output_attn
|
||||||
|
else:
|
||||||
|
a, present = output_attn
|
||||||
x = x + a
|
x = x + a
|
||||||
m = self.mlp(self.ln_2(x))
|
m = self.mlp(self.ln_2(x))
|
||||||
x = x + m
|
x = x + m
|
||||||
|
if self.output_attentions:
|
||||||
|
return attentions, x, present
|
||||||
return x, present
|
return x, present
|
||||||
|
|
||||||
|
|
||||||
@@ -567,12 +581,13 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(GPT2Model, self).__init__(config)
|
super(GPT2Model, self).__init__(config)
|
||||||
|
self.output_attentions = output_attentions
|
||||||
self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
|
self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
|
||||||
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
||||||
self.drop = nn.Dropout(config.embd_pdrop)
|
self.drop = nn.Dropout(config.embd_pdrop)
|
||||||
block = Block(config.n_ctx, config, scale=True)
|
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
|
||||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
||||||
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
@@ -617,11 +632,18 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
hidden_states = self.drop(hidden_states)
|
hidden_states = self.drop(hidden_states)
|
||||||
|
|
||||||
presents = []
|
presents = []
|
||||||
|
all_attentions = []
|
||||||
for block, layer_past in zip(self.h, past):
|
for block, layer_past in zip(self.h, past):
|
||||||
|
if self.output_attentions:
|
||||||
|
attentions, hidden_states, present = block(hidden_states, layer_past)
|
||||||
|
all_attentions.append(attentions)
|
||||||
|
else:
|
||||||
hidden_states, present = block(hidden_states, layer_past)
|
hidden_states, present = block(hidden_states, layer_past)
|
||||||
presents.append(present)
|
presents.append(present)
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
if self.output_attentions:
|
||||||
|
return all_attentions, hidden_states.view(*output_shape), presents
|
||||||
return hidden_states.view(*output_shape), presents
|
return hidden_states.view(*output_shape), presents
|
||||||
|
|
||||||
|
|
||||||
@@ -669,9 +691,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(GPT2LMHeadModel, self).__init__(config)
|
super(GPT2LMHeadModel, self).__init__(config)
|
||||||
self.transformer = GPT2Model(config)
|
self.transformer = GPT2Model(config, output_attentions=output_attentions)
|
||||||
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
@@ -684,7 +706,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
|
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
|
||||||
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
|
||||||
|
if self.transformer.output_attentions:
|
||||||
|
all_attentions, hidden_states, presents = transformer_output
|
||||||
|
else:
|
||||||
|
hidden_states, presents = transformer_output
|
||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
@@ -695,6 +721,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
shift_labels.view(-1))
|
shift_labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
|
if self.transformer.output_attentions:
|
||||||
|
return all_attentions, lm_logits, presents
|
||||||
return lm_logits, presents
|
return lm_logits, presents
|
||||||
|
|
||||||
|
|
||||||
@@ -747,9 +775,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(GPT2DoubleHeadsModel, self).__init__(config)
|
super(GPT2DoubleHeadsModel, self).__init__(config)
|
||||||
self.transformer = GPT2Model(config)
|
self.transformer = GPT2Model(config, output_attentions=output_attentions)
|
||||||
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
||||||
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
|
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
@@ -763,7 +791,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
|
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
|
||||||
|
|
||||||
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
|
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
|
||||||
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
|
||||||
|
if self.transformer.output_attentions:
|
||||||
|
all_attentions, hidden_states, presents = transformer_output
|
||||||
|
else:
|
||||||
|
hidden_states, presents = transformer_output
|
||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||||
losses = []
|
losses = []
|
||||||
@@ -777,4 +809,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
||||||
if losses:
|
if losses:
|
||||||
return losses
|
return losses
|
||||||
|
if self.transformer.output_attentions:
|
||||||
|
return all_attentions, lm_logits, mc_logits, presents
|
||||||
return lm_logits, mc_logits, presents
|
return lm_logits, mc_logits, presents
|
||||||
|
|||||||
Reference in New Issue
Block a user