From 275179a0033c9e065f97ea3bd3f0a8d7e0c4a17a Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 8 May 2019 22:24:42 +0200 Subject: [PATCH] output attentions in GPT-2 --- pytorch_pretrained_bert/modeling_gpt2.py | 60 +++++++++++++++++++----- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index 0554442b7f..d462fe04ef 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -223,7 +223,7 @@ class Conv1D(nn.Module): class Attention(nn.Module): - def __init__(self, nx, n_ctx, config, scale=False): + def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False): super(Attention, self).__init__() n_state = nx # in Attention: n_state=768 (nx=n_embd) # [switch nx => n_state from Block to Attention to keep identical to TF implem] @@ -232,6 +232,7 @@ class Attention(nn.Module): self.n_head = config.n_head self.split_size = n_state self.scale = scale + self.output_attentions = output_attentions self.c_attn = Conv1D(n_state * 3, nx) self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) @@ -247,6 +248,8 @@ class Attention(nn.Module): w = nn.Softmax(dim=-1)(w) w = self.attn_dropout(w) + if self.output_attentions: + return w, torch.matmul(w, v) return torch.matmul(w, v) def merge_heads(self, x): @@ -274,9 +277,13 @@ class Attention(nn.Module): value = torch.cat((past_value, value), dim=-2) present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking a = self._attn(query, key, value) + if self.output_attentions: + attentions, a = a a = self.merge_heads(a) a = self.c_proj(a) a = self.resid_dropout(a) + if self.output_attentions: + return attentions, a, present return a, present @@ -296,19 +303,26 @@ class MLP(nn.Module): class Block(nn.Module): - def __init__(self, n_ctx, config, scale=False): + def __init__(self, n_ctx, config, scale=False, output_attentions=False): super(Block, self).__init__() nx = config.n_embd + self.output_attentions = output_attentions self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) - self.attn = Attention(nx, n_ctx, config, scale) + self.attn = Attention(nx, n_ctx, config, scale, output_attentions) self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.mlp = MLP(4 * nx, config) def forward(self, x, layer_past=None): - a, present = self.attn(self.ln_1(x), layer_past=layer_past) + output_attn = self.attn(self.ln_1(x), layer_past=layer_past) + if self.output_attentions: + attentions, a, present = output_attn + else: + a, present = output_attn x = x + a m = self.mlp(self.ln_2(x)) x = x + m + if self.output_attentions: + return attentions, x, present return x, present @@ -567,12 +581,13 @@ class GPT2Model(GPT2PreTrainedModel): ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(GPT2Model, self).__init__(config) + self.output_attentions = output_attentions self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd) self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.drop = nn.Dropout(config.embd_pdrop) - block = Block(config.n_ctx, config, scale=True) + block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) @@ -617,11 +632,18 @@ class GPT2Model(GPT2PreTrainedModel): hidden_states = self.drop(hidden_states) presents = [] + all_attentions = [] for block, layer_past in zip(self.h, past): - hidden_states, present = block(hidden_states, layer_past) + if self.output_attentions: + attentions, hidden_states, present = block(hidden_states, layer_past) + all_attentions.append(attentions) + else: + hidden_states, present = block(hidden_states, layer_past) presents.append(present) hidden_states = self.ln_f(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) + if self.output_attentions: + return all_attentions, hidden_states.view(*output_shape), presents return hidden_states.view(*output_shape), presents @@ -669,9 +691,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(GPT2LMHeadModel, self).__init__(config) - self.transformer = GPT2Model(config) + self.transformer = GPT2Model(config, output_attentions=output_attentions) self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) self.apply(self.init_weights) @@ -684,7 +706,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens) def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None): - hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) + transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past) + if self.transformer.output_attentions: + all_attentions, hidden_states, presents = transformer_output + else: + hidden_states, presents = transformer_output lm_logits = self.lm_head(hidden_states) if lm_labels is not None: # Shift so that tokens < n predict n @@ -695,6 +721,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return loss + if self.transformer.output_attentions: + return all_attentions, lm_logits, presents return lm_logits, presents @@ -747,9 +775,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(GPT2DoubleHeadsModel, self).__init__(config) - self.transformer = GPT2Model(config) + self.transformer = GPT2Model(config, output_attentions=output_attentions) self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) self.multiple_choice_head = GPT2MultipleChoiceHead(config) self.apply(self.init_weights) @@ -763,7 +791,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens) def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None): - hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) + transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past) + if self.transformer.output_attentions: + all_attentions, hidden_states, presents = transformer_output + else: + hidden_states, presents = transformer_output lm_logits = self.lm_head(hidden_states) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) losses = [] @@ -777,4 +809,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) if losses: return losses + if self.transformer.output_attentions: + return all_attentions, lm_logits, mc_logits, presents return lm_logits, mc_logits, presents