From 42e00cf9e1969973a563db2900ed86bbf58dbc71 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Mon, 19 Aug 2019 22:43:02 -0400
Subject: [PATCH] Pruning saved to configuration first try

---
 pytorch_transformers/modeling_bert.py         |  6 ++
 pytorch_transformers/modeling_utils.py        | 10 +++
 .../tests/modeling_common_test.py             | 66 +++++++++++++++++++
 3 files changed, 82 insertions(+)

diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py
index f918afff3e..4a68c2b96b 100644
--- a/pytorch_transformers/modeling_bert.py
+++ b/pytorch_transformers/modeling_bert.py
@@ -649,6 +649,12 @@ class BertModel(BertPreTrainedModel):
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config)
 
+        # Re-apply pruning recorded in the config; iterate over a snapshot
+        # because prune_heads() writes back into config.pruned_heads.
+        for layer, heads in list(getattr(config, "pruned_heads", {}).items()):
+            if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
+                self.prune_heads({int(layer): list(map(int, heads))})
+
         self.apply(self.init_weights)
 
     def _resize_token_embeddings(self, new_num_tokens):
diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py
index 0d4fce67f0..351fbfd0e1 100644
--- a/pytorch_transformers/modeling_utils.py
+++ b/pytorch_transformers/modeling_utils.py
@@ -104,6 +104,7 @@ class PretrainedConfig(object):
         self.output_attentions = kwargs.pop('output_attentions', False)
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
         self.torchscript = kwargs.pop('torchscript', False)
+        self.pruned_heads = kwargs.pop('pruned_heads', {})
 
     def save_pretrained(self, save_directory):
         """ Save a configuration object to the directory `save_directory`, so that it
@@ -363,6 +364,15 @@ class PreTrainedModel(nn.Module):
                 heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
         """
         base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
+
+        # Record the pruned heads on the config (layer index stored as a string key),
+        # adding each head once and without aliasing the caller's list.
+        for layer, heads in heads_to_prune.items():
+            recorded = self.config.pruned_heads.setdefault(str(layer), [])
+            for head in heads:
+                if head not in recorded:
+                    recorded.append(head)
+
         base_model._prune_heads(heads_to_prune)
 
     def save_pretrained(self, save_directory):
diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py
index 8b9a2ffd17..7ed1eddbfb 100644
--- a/pytorch_transformers/tests/modeling_common_test.py
+++ b/pytorch_transformers/tests/modeling_common_test.py
@@ -219,6 +219,9 @@ class CommonTestCases:
                 del inputs_dict["head_mask"]
 
             for model_class in self.all_model_classes:
+                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+                # inputs_dict was just rebuilt, so drop head_mask again (it was deleted above the loop).
+                inputs_dict.pop("head_mask", None)
                 config.output_attentions = True
                 config.output_hidden_states = False
                 model = model_class(config=config)
@@ -237,6 +240,69 @@ class CommonTestCases:
                 self.assertEqual(
                     attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
 
+        def test_head_pruning_save_load_from_pretrained(self):
+            """Heads pruned before save_pretrained() must still be pruned after from_pretrained()."""
+            if not self.test_pruning:
+                return
+
+            for model_class in self.all_model_classes:
+                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+                # Drop any head_mask, as test_head_pruning does before pruning.
+                inputs_dict.pop("head_mask", None)
+                config.output_attentions = True
+                config.output_hidden_states = False
+                model = model_class(config=config)
+                model.eval()
+                heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
+                                  -1: [0]}
+                model.prune_heads(heads_to_prune)
+
+                directory = "pruned_model"
+                if not os.path.exists(directory):
+                    os.makedirs(directory)
+                try:
+                    model.save_pretrained(directory)
+                    model = model_class.from_pretrained(directory)
+                finally:
+                    # Clean up the scratch directory even when saving or loading fails.
+                    shutil.rmtree(directory)
+
+                outputs = model(**inputs_dict)
+                attentions = outputs[-1]
+                self.assertEqual(
+                    attentions[0].shape[-3], 1)
+                self.assertEqual(
+                    attentions[1].shape[-3], self.model_tester.num_attention_heads)
+                self.assertEqual(
+                    attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
+
+        def test_head_pruning_save_load_from_config_init(self):
+            """A model built from a config that lists pruned_heads must start out pruned."""
+            if not self.test_pruning:
+                return
+
+            for model_class in self.all_model_classes:
+                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+                # Drop any head_mask, as test_head_pruning does before pruning.
+                inputs_dict.pop("head_mask", None)
+                config.output_attentions = True
+                config.output_hidden_states = False
+
+                heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
+                                  -1: [0]}
+                config.pruned_heads = heads_to_prune
+
+                model = model_class(config=config)
+                model.eval()
+
+                outputs = model(**inputs_dict)
+                attentions = outputs[-1]
+                self.assertEqual(
+                    attentions[0].shape[-3], 1)
+                self.assertEqual(
+                    attentions[1].shape[-3], self.model_tester.num_attention_heads)
+                self.assertEqual(
+                    attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
 
         def test_hidden_states_output(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()