Blocks deletion from already deleted heads. Necessary integration test.
Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added.
This commit is contained in:
@@ -651,6 +651,7 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
if hasattr(config, "pruned_heads"):
|
||||
pruned_heads = config.pruned_heads.copy().items()
|
||||
config.pruned_heads = {}
|
||||
for layer, heads in pruned_heads:
|
||||
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
|
||||
self.prune_heads({int(layer): list(map(int, heads))})
|
||||
|
||||
@@ -455,6 +455,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
|
||||
if hasattr(config, "pruned_heads"):
|
||||
pruned_heads = config.pruned_heads.copy().items()
|
||||
config.pruned_heads = {}
|
||||
for layer, heads in pruned_heads:
|
||||
if self.h[int(layer)].attn.n_head == config.n_head:
|
||||
self.prune_heads({int(layer): list(map(int, heads))})
|
||||
|
||||
@@ -458,6 +458,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
if hasattr(config, "pruned_heads"):
|
||||
pruned_heads = config.pruned_heads.copy().items()
|
||||
config.pruned_heads = {}
|
||||
for layer, heads in pruned_heads:
|
||||
if self.h[int(layer)].attn.n_head == config.n_head:
|
||||
self.prune_heads({int(layer): list(map(int, heads))})
|
||||
|
||||
@@ -201,6 +201,10 @@ class PretrainedConfig(object):
|
||||
# Load config
|
||||
config = cls.from_json_file(resolved_config_file)
|
||||
|
||||
if hasattr(config, 'pruned_heads'):
|
||||
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
|
||||
|
||||
|
||||
# Update config with kwargs if needed
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
@@ -365,15 +369,22 @@ class PreTrainedModel(nn.Module):
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
|
||||
to_be_pruned = {}
|
||||
|
||||
for layer, heads in heads_to_prune.items():
|
||||
if str(layer) not in self.config.pruned_heads:
|
||||
self.config.pruned_heads[str(layer)] = heads
|
||||
if int(layer) not in self.config.pruned_heads:
|
||||
self.config.pruned_heads[int(layer)] = heads
|
||||
to_be_pruned[int(layer)] = heads
|
||||
else:
|
||||
for head in heads:
|
||||
if head not in self.config.pruned_heads[str(layer)]:
|
||||
self.config.pruned_heads[str(layer)].append(head)
|
||||
if head not in self.config.pruned_heads[int(layer)]:
|
||||
self.config.pruned_heads[int(layer)].append(head)
|
||||
to_be_pruned[int(layer)].append(head)
|
||||
else:
|
||||
logger.warning(f"Tried to remove head {head} of layer {layer} but it was already removed. "
|
||||
f"The removed heads are {heads_to_prune}")
|
||||
|
||||
base_model._prune_heads(heads_to_prune)
|
||||
base_model._prune_heads(to_be_pruned)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a model and its configuration file to a directory, so that it
|
||||
|
||||
@@ -561,6 +561,7 @@ class XLMModel(XLMPreTrainedModel):
|
||||
|
||||
if hasattr(config, "pruned_heads"):
|
||||
pruned_heads = config.pruned_heads.copy().items()
|
||||
config.pruned_heads = {}
|
||||
for layer, heads in pruned_heads:
|
||||
if self.attentions[int(layer)].n_heads == config.n_heads:
|
||||
self.prune_heads({int(layer): list(map(int, heads))})
|
||||
|
||||
@@ -262,12 +262,9 @@ class CommonTestCases:
|
||||
|
||||
outputs = model(**inputs_dict)
|
||||
attentions = outputs[-1]
|
||||
self.assertEqual(
|
||||
attentions[0].shape[-3], 1)
|
||||
self.assertEqual(
|
||||
attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(
|
||||
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
self.assertEqual(attentions[0].shape[-3], 1)
|
||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
|
||||
shutil.rmtree(directory)
|
||||
|
||||
@@ -293,12 +290,67 @@ class CommonTestCases:
|
||||
|
||||
outputs = model(**inputs_dict)
|
||||
attentions = outputs[-1]
|
||||
self.assertEqual(
|
||||
attentions[0].shape[-3], 1)
|
||||
self.assertEqual(
|
||||
attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(
|
||||
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
|
||||
self.assertEqual(attentions[0].shape[-3], 1)
|
||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
|
||||
def test_head_pruning_integration(self):
|
||||
if not self.test_pruning:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = False
|
||||
|
||||
heads_to_prune = {0: [0], 1: [1, 2]}
|
||||
config.pruned_heads = heads_to_prune
|
||||
|
||||
model = model_class(config=config)
|
||||
model.eval()
|
||||
|
||||
outputs = model(**inputs_dict)
|
||||
attentions = outputs[-1]
|
||||
|
||||
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
|
||||
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||
|
||||
directory = "pruned_model"
|
||||
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
model.save_pretrained(directory)
|
||||
model = model_class.from_pretrained(directory)
|
||||
shutil.rmtree(directory)
|
||||
|
||||
outputs = model(**inputs_dict)
|
||||
attentions = outputs[-1]
|
||||
|
||||
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
|
||||
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||
|
||||
heads_to_prune = {0: [0], 2: [1, 2]}
|
||||
model.prune_heads(heads_to_prune)
|
||||
|
||||
outputs = model(**inputs_dict)
|
||||
attentions = outputs[-1]
|
||||
|
||||
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads -1)
|
||||
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
|
||||
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2)
|
||||
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||
|
||||
self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
|
||||
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user