From f12007e4216e2bdb278b40c731eb62626181cd28 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 17 Jun 2019 14:19:40 +0200 Subject: [PATCH] add head masking and pruning to openai GPT --- pytorch_pretrained_bert/modeling_openai.py | 100 ++++++++++++++++----- tests/modeling_openai_test.py | 70 +++++++++++++++ 2 files changed, 149 insertions(+), 21 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 333cd88f65..bbc60ffd2c 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -36,6 +36,7 @@ from torch.nn.parameter import Parameter from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME from .modeling import BertLayerNorm as LayerNorm +from .modeling_gpt2 import prune_conv1d_layer logger = logging.getLogger(__name__) @@ -256,7 +257,7 @@ class Conv1D(nn.Module): class Attention(nn.Module): - def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False): + def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False): super(Attention, self).__init__() n_state = nx # in Attention: n_state=768 (nx=n_embd) # [switch nx => n_state from Block to Attention to keep identical to TF implem] @@ -265,13 +266,31 @@ class Attention(nn.Module): self.n_head = config.n_head self.split_size = n_state self.scale = scale + self.output_attentions = output_attentions + self.keep_multihead_output = keep_multihead_output + self.multihead_output = None + self.c_attn = Conv1D(n_state * 3, 1, nx) self.c_proj = Conv1D(n_state, 1, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) - def _attn(self, q, k, v): + def prune_heads(self, heads): + mask = torch.ones(self.n_head, self.split_size // self.n_head) + for head in heads: + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + index_attn = torch.cat([index, 
index + self.split_size, index + (2*self.split_size)]) + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + # Update hyper params + self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) + self.n_head = self.n_head - len(heads) + + def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k) if self.scale: w = w / math.sqrt(v.size(-1)) @@ -282,6 +301,11 @@ class Attention(nn.Module): w = nn.Softmax(dim=-1)(w) w = self.attn_dropout(w) + + # Mask heads if we want to + if head_mask is not None: + w = w * head_mask + if self.output_attentions: return w, torch.matmul(w, v) return torch.matmul(w, v) @@ -299,13 +323,18 @@ class Attention(nn.Module): else: return x.permute(0, 2, 1, 3) - def forward(self, x): + def forward(self, x, head_mask=None): x = self.c_attn(x) query, key, value = x.split(self.split_size, dim=2) query = self.split_heads(query) key = self.split_heads(key, k=True) value = self.split_heads(value) - a = self._attn(query, key, value) + + a = self._attn(query, key, value, head_mask) + if self.keep_multihead_output: + self.multihead_output = a + self.multihead_output.retain_grad() + if self.output_attentions: attentions, a = a a = self.merge_heads(a) @@ -332,17 +361,17 @@ class MLP(nn.Module): class Block(nn.Module): - def __init__(self, n_ctx, config, scale=False, output_attentions=False): + def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False): super(Block, self).__init__() nx = config.n_embd self.output_attentions = output_attentions - self.attn = Attention(nx, n_ctx, config, scale, output_attentions) + self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output) self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.mlp = MLP(4 * nx, config) self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) - def forward(self, x): - a = self.attn(x) + 
def forward(self, x, head_mask=None): + a = self.attn(x, head_mask=head_mask) if self.output_attentions: attentions, a = a n = self.ln_1(x + a) @@ -614,13 +643,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ``` """ - def __init__(self, config, output_attentions=False): + def __init__(self, config, output_attentions=False, keep_multihead_output=False): super(OpenAIGPTModel, self).__init__(config) self.output_attentions = output_attentions self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd) self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) self.drop = nn.Dropout(config.embd_pdrop) - block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions) + block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions, + keep_multihead_output=keep_multihead_output) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.apply(self.init_weights) @@ -639,7 +669,20 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): # Copy word embeddings from the previous weights self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :] - def forward(self, input_ids, position_ids=None, token_type_ids=None): + def prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + def get_multihead_outputs(self): + """ Gather all multi-head outputs. 
+ Return: list (layers) of multihead module outputs with gradients + """ + return [h.attn.multihead_output for h in self.h] + + def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None): if position_ids is None: # This was used when we had a single embedding matrice from position and token embeddings # start = self.config.vocab_size + self.config.n_special @@ -648,6 +691,17 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # Prepare head mask if needed + # 1.0 in head_mask indicates we mask the head + # attention_probs has shape bsz x n_heads x N x N + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each instance in batch + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility + head_mask = (1.0 - head_mask) + input_shape = input_ids.size() input_ids = input_ids.view(-1, input_ids.size(-1)) position_ids = position_ids.view(-1, position_ids.size(-1)) @@ -664,11 +718,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): all_attentions = [] for block in self.h: + outputs = block(hidden_states, head_mask) if self.output_attentions: - attentions, hidden_states = block(hidden_states) + attentions, hidden_states = outputs all_attentions.append(attentions) else: - hidden_states = block(hidden_states) + hidden_states = outputs output_shape = input_shape + (hidden_states.size(-1),) if self.output_attentions: return all_attentions, hidden_states.view(*output_shape) @@ -731,9 +786,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ``` """ - def __init__(self, config, output_attentions=False): + def __init__(self, config, output_attentions=False, 
keep_multihead_output=False): super(OpenAIGPTLMHeadModel, self).__init__(config) - self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions) + self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions, + keep_multihead_output=keep_multihead_output) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.apply(self.init_weights) @@ -745,8 +801,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): self.transformer.set_num_special_tokens(num_special_tokens) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens) - def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None): - hidden_states = self.transformer(input_ids, position_ids, token_type_ids) + def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None): + hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask) if self.transformer.output_attentions: all_attentions, hidden_states = hidden_states lm_logits = self.lm_head(hidden_states) @@ -825,9 +881,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ``` """ - def __init__(self, config, output_attentions=False): + def __init__(self, config, output_attentions=False, keep_multihead_output=False): super(OpenAIGPTDoubleHeadsModel, self).__init__(config) - self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions) + self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions, + keep_multihead_output=keep_multihead_output) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config) self.apply(self.init_weights) @@ -840,8 +897,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): self.transformer.set_num_special_tokens(num_special_tokens) 
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens) - def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): - hidden_states = self.transformer(input_ids, position_ids, token_type_ids) + def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, + position_ids=None, head_mask=None): + hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask) if self.transformer.output_attentions: all_attentions, hidden_states = hidden_states lm_logits = self.lm_head(hidden_states) diff --git a/tests/modeling_openai_test.py b/tests/modeling_openai_test.py index 4e7d9d542b..08353cdd18 100644 --- a/tests/modeling_openai_test.py +++ b/tests/modeling_openai_test.py @@ -182,6 +182,73 @@ class OpenAIGPTModelTest(unittest.TestCase): [list(l.size()) for l in result["loss"]], [[], []]) + def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids, + mc_labels, lm_labels, mc_token_ids): + for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel): + model = model_class(config=config, keep_multihead_output=True) + model.eval() + head_mask = torch.ones(self.n_head).to(input_ids.device) + head_mask[0] = 0.0 + head_mask[-1] = 0.0 # Mask all but the first and last heads + if isinstance(model, OpenAIGPTDoubleHeadsModel): + output = model(input_ids, mc_token_ids, head_mask=head_mask) + else: + output = model(input_ids, head_mask=head_mask) + + output = sum(t.sum() for t in output[:-1]) + output = output.sum() + output.backward() + multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs() + + self.parent.assertEqual(len(multihead_outputs), self.n_layer) + self.parent.assertListEqual( + list(multihead_outputs[0].size()), + [self.batch_size * self.n_choices, self.n_head, + self.seq_length, 
self.n_embd // self.n_head]) + self.parent.assertEqual( + len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()), + 0) + self.parent.assertEqual( + len(multihead_outputs[0][:, 0, :, :].nonzero()), + self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head) + self.parent.assertEqual( + len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()), + self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head) + + def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids, + mc_labels, lm_labels, mc_token_ids): + for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel): + model = model_class(config=config, keep_multihead_output=True) + model.eval() + transformer = model if isinstance(model, OpenAIGPTModel) else model.transformer + heads_to_prune = {0: list(range(1, self.n_head)), + -1: [0]} + transformer.prune_heads(heads_to_prune) + if isinstance(model, OpenAIGPTDoubleHeadsModel): + output = model(input_ids, mc_token_ids) + else: + output = model(input_ids) + + output = sum(t.sum() for t in output[:-1]) + output = output.sum() + output.backward() + multihead_outputs = transformer.get_multihead_outputs() + + self.parent.assertEqual(len(multihead_outputs), self.n_layer) + self.parent.assertListEqual( + list(multihead_outputs[0].size()), + [self.batch_size * self.n_choices, 1, + self.seq_length, self.n_embd // self.n_head]) + self.parent.assertListEqual( + list(multihead_outputs[1].size()), + [self.batch_size * self.n_choices, self.n_head, + self.seq_length, self.n_embd // self.n_head]) + self.parent.assertListEqual( + list(multihead_outputs[-1].size()), + [self.batch_size * self.n_choices, self.n_head-1, + self.seq_length, self.n_embd // self.n_head]) + + def test_default(self): self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self)) @@ -220,6 +287,9 @@ class OpenAIGPTModelTest(unittest.TestCase): 
tester.check_openai_double_heads_output(output_result) tester.check_openai_double_heads_loss_output(output_result) + tester.create_and_check_openai_for_headmasking(*config_and_inputs) + tester.create_and_check_openai_for_head_pruning(*config_and_inputs) + @classmethod def ids_tensor(cls, shape, vocab_size, rng=None, name=None): """Creates a random int32 tensor of the shape within the vocab size."""