From 87747518e94860e730606848e6a8d2ed68ae8a51 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 21 Aug 2019 21:20:39 -0400 Subject: [PATCH] Blocks deletion from already deleted heads. Necessary integration test. Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added. --- pytorch_transformers/modeling_bert.py | 1 + pytorch_transformers/modeling_gpt2.py | 1 + pytorch_transformers/modeling_openai.py | 1 + pytorch_transformers/modeling_utils.py | 21 +++-- pytorch_transformers/modeling_xlm.py | 1 + .../tests/modeling_common_test.py | 76 ++++++++++++++++--- 6 files changed, 84 insertions(+), 17 deletions(-) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 4a68c2b96b..5a65e442d0 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -651,6 +651,7 @@ class BertModel(BertPreTrainedModel): if hasattr(config, "pruned_heads"): pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} for layer, heads in pruned_heads: if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads: self.prune_heads({int(layer): list(map(int, heads))}) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 23cc7f5313..8aa5347c71 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -455,6 +455,7 @@ class GPT2Model(GPT2PreTrainedModel): if hasattr(config, "pruned_heads"): pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} for layer, heads in pruned_heads: if self.h[int(layer)].attn.n_head == config.n_head: self.prune_heads({int(layer): list(map(int, heads))}) diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index c640b7c86c..ce3768c676 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -458,6 +458,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): if hasattr(config, "pruned_heads"): pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} for layer, heads in pruned_heads: if self.h[int(layer)].attn.n_head == config.n_head: self.prune_heads({int(layer): list(map(int, heads))}) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 351fbfd0e1..0a47d07fd4 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -201,6 +201,10 @@ class PretrainedConfig(object): # Load config config = cls.from_json_file(resolved_config_file) + if hasattr(config, 'pruned_heads'): + config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()} + + # Update config with kwargs if needed to_remove = [] for key, value in kwargs.items(): @@ -365,15 +369,22 @@ class PreTrainedModel(nn.Module): """ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + to_be_pruned = {} + for layer, heads in heads_to_prune.items(): - if str(layer) not in self.config.pruned_heads: - self.config.pruned_heads[str(layer)] = heads + if int(layer) not in self.config.pruned_heads: + self.config.pruned_heads[int(layer)] = heads + to_be_pruned[int(layer)] = heads else: for head in heads: - if head not in self.config.pruned_heads[str(layer)]: - self.config.pruned_heads[str(layer)].append(head) + if head not in self.config.pruned_heads[int(layer)]: + self.config.pruned_heads[int(layer)].append(head) + to_be_pruned[int(layer)].append(head) + else: + logger.warning(f"Tried to remove head {head} of layer {layer} but it was already removed. " + f"The removed heads are {heads_to_prune}") - base_model._prune_heads(heads_to_prune) + base_model._prune_heads(to_be_pruned) def save_pretrained(self, save_directory): """ Save a model and its configuration file to a directory, so that it diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index cf121eee41..1e0f8d7c77 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -561,6 +561,7 @@ class XLMModel(XLMPreTrainedModel): if hasattr(config, "pruned_heads"): pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} for layer, heads in pruned_heads: if self.attentions[int(layer)].n_heads == config.n_heads: self.prune_heads({int(layer): list(map(int, heads))}) diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index c06c501153..8b1a70fcf3 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -262,12 +262,9 @@ class CommonTestCases: outputs = model(**inputs_dict) attentions = outputs[-1] - self.assertEqual( - attentions[0].shape[-3], 1) - self.assertEqual( - attentions[1].shape[-3], self.model_tester.num_attention_heads) - self.assertEqual( - attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + self.assertEqual(attentions[0].shape[-3], 1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) shutil.rmtree(directory) @@ -293,12 +290,67 @@ class CommonTestCases: outputs = model(**inputs_dict) attentions = outputs[-1] - self.assertEqual( - attentions[0].shape[-3], 1) - self.assertEqual( - attentions[1].shape[-3], self.model_tester.num_attention_heads) - self.assertEqual( - attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + + self.assertEqual(attentions[0].shape[-3], 1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + + def test_head_pruning_integration(self): + if not self.test_pruning: + return + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + + config.output_attentions = True + config.output_hidden_states = False + + heads_to_prune = {0: [0], 1: [1, 2]} + config.pruned_heads = heads_to_prune + + model = model_class(config=config) + model.eval() + + outputs = model(**inputs_dict) + attentions = outputs[-1] + + self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2) + self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads) + + directory = "pruned_model" + + if not os.path.exists(directory): + os.makedirs(directory) + model.save_pretrained(directory) + model = model_class.from_pretrained(directory) + shutil.rmtree(directory) + + outputs = model(**inputs_dict) + attentions = outputs[-1] + + self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2) + self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads) + + heads_to_prune = {0: [0], 2: [1, 2]} + model.prune_heads(heads_to_prune) + + outputs = model(**inputs_dict) + attentions = outputs[-1] + + self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads -1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2) + self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2) + self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads) + + self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]}) + def test_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()