From 719cb3738d442431d246c107899b40441c3dd5ae Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 21 Aug 2019 20:12:06 -0400 Subject: [PATCH] Pruning for GPT and GPT-2 --- pytorch_transformers/modeling_gpt2.py | 6 ++++++ pytorch_transformers/modeling_openai.py | 6 ++++++ .../tests/modeling_common_test.py | 17 ++++++++++++----- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 283dc68a6a..23cc7f5313 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -453,6 +453,12 @@ class GPT2Model(GPT2PreTrainedModel): self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + if hasattr(config, "pruned_heads"): + pruned_heads = config.pruned_heads.copy().items() + for layer, heads in pruned_heads: + if self.h[int(layer)].attn.n_head == config.n_head: + self.prune_heads({int(layer): list(map(int, heads))}) + self.apply(self.init_weights) def _resize_token_embeddings(self, new_num_tokens): diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index 690aa7812b..c640b7c86c 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -456,6 +456,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) + if hasattr(config, "pruned_heads"): + pruned_heads = config.pruned_heads.copy().items() + for layer, heads in pruned_heads: + if self.h[int(layer)].attn.n_head == config.n_head: + self.prune_heads({int(layer): list(map(int, heads))}) + self.apply(self.init_weights) def _resize_token_embeddings(self, new_num_tokens): diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index dbb041ab05..c06c501153 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -213,13 +213,12 @@ class CommonTestCases: if not self.test_pruning: return - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - if "head_mask" in inputs_dict: - del inputs_dict["head_mask"] - for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + config.output_attentions = True config.output_hidden_states = False model = model_class(config=config) @@ -244,6 +243,10 @@ class CommonTestCases: for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + config.output_attentions = True config.output_hidden_states = False model = model_class(config=config) @@ -274,6 +277,10 @@ class CommonTestCases: for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + config.output_attentions = True config.output_hidden_states = False