Blocks deletion from already deleted heads. Necessary integration test.
Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added.
This commit is contained in:
@@ -651,6 +651,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
if hasattr(config, "pruned_heads"):
|
if hasattr(config, "pruned_heads"):
|
||||||
pruned_heads = config.pruned_heads.copy().items()
|
pruned_heads = config.pruned_heads.copy().items()
|
||||||
|
config.pruned_heads = {}
|
||||||
for layer, heads in pruned_heads:
|
for layer, heads in pruned_heads:
|
||||||
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
|
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
|
||||||
self.prune_heads({int(layer): list(map(int, heads))})
|
self.prune_heads({int(layer): list(map(int, heads))})
|
||||||
|
|||||||
@@ -455,6 +455,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
if hasattr(config, "pruned_heads"):
|
if hasattr(config, "pruned_heads"):
|
||||||
pruned_heads = config.pruned_heads.copy().items()
|
pruned_heads = config.pruned_heads.copy().items()
|
||||||
|
config.pruned_heads = {}
|
||||||
for layer, heads in pruned_heads:
|
for layer, heads in pruned_heads:
|
||||||
if self.h[int(layer)].attn.n_head == config.n_head:
|
if self.h[int(layer)].attn.n_head == config.n_head:
|
||||||
self.prune_heads({int(layer): list(map(int, heads))})
|
self.prune_heads({int(layer): list(map(int, heads))})
|
||||||
|
|||||||
@@ -458,6 +458,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
|
|
||||||
if hasattr(config, "pruned_heads"):
|
if hasattr(config, "pruned_heads"):
|
||||||
pruned_heads = config.pruned_heads.copy().items()
|
pruned_heads = config.pruned_heads.copy().items()
|
||||||
|
config.pruned_heads = {}
|
||||||
for layer, heads in pruned_heads:
|
for layer, heads in pruned_heads:
|
||||||
if self.h[int(layer)].attn.n_head == config.n_head:
|
if self.h[int(layer)].attn.n_head == config.n_head:
|
||||||
self.prune_heads({int(layer): list(map(int, heads))})
|
self.prune_heads({int(layer): list(map(int, heads))})
|
||||||
|
|||||||
@@ -201,6 +201,10 @@ class PretrainedConfig(object):
|
|||||||
# Load config
|
# Load config
|
||||||
config = cls.from_json_file(resolved_config_file)
|
config = cls.from_json_file(resolved_config_file)
|
||||||
|
|
||||||
|
if hasattr(config, 'pruned_heads'):
|
||||||
|
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
|
||||||
|
|
||||||
|
|
||||||
# Update config with kwargs if needed
|
# Update config with kwargs if needed
|
||||||
to_remove = []
|
to_remove = []
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
@@ -365,15 +369,22 @@ class PreTrainedModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||||
|
|
||||||
|
to_be_pruned = {}
|
||||||
|
|
||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
if str(layer) not in self.config.pruned_heads:
|
if int(layer) not in self.config.pruned_heads:
|
||||||
self.config.pruned_heads[str(layer)] = heads
|
self.config.pruned_heads[int(layer)] = heads
|
||||||
|
to_be_pruned[int(layer)] = heads
|
||||||
else:
|
else:
|
||||||
for head in heads:
|
for head in heads:
|
||||||
if head not in self.config.pruned_heads[str(layer)]:
|
if head not in self.config.pruned_heads[int(layer)]:
|
||||||
self.config.pruned_heads[str(layer)].append(head)
|
self.config.pruned_heads[int(layer)].append(head)
|
||||||
|
to_be_pruned[int(layer)].append(head)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Tried to remove head {head} of layer {layer} but it was already removed. "
|
||||||
|
f"The removed heads are {heads_to_prune}")
|
||||||
|
|
||||||
base_model._prune_heads(heads_to_prune)
|
base_model._prune_heads(to_be_pruned)
|
||||||
|
|
||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory):
|
||||||
""" Save a model and its configuration file to a directory, so that it
|
""" Save a model and its configuration file to a directory, so that it
|
||||||
|
|||||||
@@ -561,6 +561,7 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
|
|
||||||
if hasattr(config, "pruned_heads"):
|
if hasattr(config, "pruned_heads"):
|
||||||
pruned_heads = config.pruned_heads.copy().items()
|
pruned_heads = config.pruned_heads.copy().items()
|
||||||
|
config.pruned_heads = {}
|
||||||
for layer, heads in pruned_heads:
|
for layer, heads in pruned_heads:
|
||||||
if self.attentions[int(layer)].n_heads == config.n_heads:
|
if self.attentions[int(layer)].n_heads == config.n_heads:
|
||||||
self.prune_heads({int(layer): list(map(int, heads))})
|
self.prune_heads({int(layer): list(map(int, heads))})
|
||||||
|
|||||||
@@ -262,12 +262,9 @@ class CommonTestCases:
|
|||||||
|
|
||||||
outputs = model(**inputs_dict)
|
outputs = model(**inputs_dict)
|
||||||
attentions = outputs[-1]
|
attentions = outputs[-1]
|
||||||
self.assertEqual(
|
self.assertEqual(attentions[0].shape[-3], 1)
|
||||||
attentions[0].shape[-3], 1)
|
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||||
self.assertEqual(
|
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
|
||||||
self.assertEqual(
|
|
||||||
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
|
||||||
|
|
||||||
shutil.rmtree(directory)
|
shutil.rmtree(directory)
|
||||||
|
|
||||||
@@ -293,12 +290,67 @@ class CommonTestCases:
|
|||||||
|
|
||||||
outputs = model(**inputs_dict)
|
outputs = model(**inputs_dict)
|
||||||
attentions = outputs[-1]
|
attentions = outputs[-1]
|
||||||
self.assertEqual(
|
|
||||||
attentions[0].shape[-3], 1)
|
self.assertEqual(attentions[0].shape[-3], 1)
|
||||||
self.assertEqual(
|
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||||
attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
self.assertEqual(
|
|
||||||
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
def test_head_pruning_integration(self):
|
||||||
|
if not self.test_pruning:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
if "head_mask" in inputs_dict:
|
||||||
|
del inputs_dict["head_mask"]
|
||||||
|
|
||||||
|
config.output_attentions = True
|
||||||
|
config.output_hidden_states = False
|
||||||
|
|
||||||
|
heads_to_prune = {0: [0], 1: [1, 2]}
|
||||||
|
config.pruned_heads = heads_to_prune
|
||||||
|
|
||||||
|
model = model_class(config=config)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
outputs = model(**inputs_dict)
|
||||||
|
attentions = outputs[-1]
|
||||||
|
|
||||||
|
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
|
||||||
|
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
|
||||||
|
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||||
|
|
||||||
|
directory = "pruned_model"
|
||||||
|
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
os.makedirs(directory)
|
||||||
|
model.save_pretrained(directory)
|
||||||
|
model = model_class.from_pretrained(directory)
|
||||||
|
shutil.rmtree(directory)
|
||||||
|
|
||||||
|
outputs = model(**inputs_dict)
|
||||||
|
attentions = outputs[-1]
|
||||||
|
|
||||||
|
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
|
||||||
|
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
|
||||||
|
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||||
|
|
||||||
|
heads_to_prune = {0: [0], 2: [1, 2]}
|
||||||
|
model.prune_heads(heads_to_prune)
|
||||||
|
|
||||||
|
outputs = model(**inputs_dict)
|
||||||
|
attentions = outputs[-1]
|
||||||
|
|
||||||
|
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads -1)
|
||||||
|
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
|
||||||
|
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2)
|
||||||
|
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||||
|
|
||||||
|
self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
|
||||||
|
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user