From e211785ada54a761bd1b0f10e3878539d9d1d72d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 2 May 2019 18:31:26 +0200 Subject: [PATCH] extract attention weights from GPT --- pytorch_pretrained_bert/modeling_openai.py | 48 +++++++++++++++++----- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 7ac3782b42..77e9cda349 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -253,7 +253,7 @@ class Conv1D(nn.Module): class Attention(nn.Module): - def __init__(self, nx, n_ctx, config, scale=False): + def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False): super(Attention, self).__init__() n_state = nx # in Attention: n_state=768 (nx=n_embd) # [switch nx => n_state from Block to Attention to keep identical to TF implem] @@ -262,6 +262,7 @@ class Attention(nn.Module): self.n_head = config.n_head self.split_size = n_state self.scale = scale + self.output_attentions = output_attentions self.c_attn = Conv1D(n_state * 3, 1, nx) self.c_proj = Conv1D(n_state, 1, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) @@ -278,6 +279,8 @@ class Attention(nn.Module): w = nn.Softmax(dim=-1)(w) w = self.attn_dropout(w) + if self.output_attentions: + return w, torch.matmul(w, v) return torch.matmul(w, v) def merge_heads(self, x): @@ -300,9 +303,13 @@ class Attention(nn.Module): key = self.split_heads(key, k=True) value = self.split_heads(value) a = self._attn(query, key, value) + if self.output_attentions: + attentions, a = a a = self.merge_heads(a) a = self.c_proj(a) a = self.resid_dropout(a) + if self.output_attentions: + return attentions, a return a @@ -322,19 +329,24 @@ class MLP(nn.Module): class Block(nn.Module): - def __init__(self, n_ctx, config, scale=False): + def __init__(self, n_ctx, config, scale=False, output_attentions=False): super(Block, self).__init__() nx = config.n_embd - self.attn = Attention(nx, n_ctx, config, scale) + self.output_attentions = output_attentions + self.attn = Attention(nx, n_ctx, config, scale, output_attentions) self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.mlp = MLP(4 * nx, config) self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) def forward(self, x): a = self.attn(x) + if self.output_attentions: + attentions, a = a n = self.ln_1(x + a) m = self.mlp(n) h = self.ln_2(n + m) + if self.output_attentions: + return attentions, h return h @@ -591,12 +603,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(OpenAIGPTModel, self).__init__(config) + self.output_attentions = output_attentions self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd) self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) self.drop = nn.Dropout(config.embd_pdrop) - block = Block(config.n_ctx, config, scale=True) + block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.apply(self.init_weights) @@ -639,9 +652,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): # Add the position information to the input embeddings # h = e.sum(dim=2) hidden_states = inputs_embeds + position_embeds + token_type_embeds + all_attentions = [] for block in self.h: - hidden_states = block(hidden_states) + if self.output_attentions: + attentions, hidden_states = block(hidden_states) + all_attentions.append(attentions) + else: + hidden_states = block(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) + if self.output_attentions: + return all_attentions, hidden_states.view(*output_shape) return hidden_states.view(*output_shape) @@ -701,9 +721,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(OpenAIGPTLMHeadModel, self).__init__(config) - self.transformer = OpenAIGPTModel(config) + self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.apply(self.init_weights) @@ -716,6 +736,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None): hidden_states = self.transformer(input_ids, position_ids, token_type_ids) + if self.transformer.output_attentions: + all_attentions, hidden_states = hidden_states lm_logits = self.lm_head(hidden_states) if lm_labels is not None: # Shift so that tokens < n predict n @@ -726,6 +748,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return loss + if self.transformer.output_attentions: + return all_attentions, lm_logits return lm_logits @@ -790,9 +814,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(OpenAIGPTDoubleHeadsModel, self).__init__(config) - self.transformer = OpenAIGPTModel(config) + self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config) self.apply(self.init_weights) @@ -806,6 +830,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): hidden_states = self.transformer(input_ids, position_ids, token_type_ids) + if self.transformer.output_attentions: + all_attentions, hidden_states = hidden_states lm_logits = self.lm_head(hidden_states) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) losses = [] @@ -819,4 +845,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) if losses: return losses + if self.transformer.output_attentions: + return all_attentions, lm_logits, mc_logits return lm_logits, mc_logits