Pruning saved to configuration first try
This commit is contained in:
@@ -649,6 +649,12 @@ class BertModel(BertPreTrainedModel):
|
|||||||
self.encoder = BertEncoder(config)
|
self.encoder = BertEncoder(config)
|
||||||
self.pooler = BertPooler(config)
|
self.pooler = BertPooler(config)
|
||||||
|
|
||||||
|
if hasattr(config, "pruned_heads"):
|
||||||
|
pruned_heads = config.pruned_heads.copy().items()
|
||||||
|
for layer, heads in pruned_heads:
|
||||||
|
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
|
||||||
|
self.prune_heads({int(layer): list(map(int, heads))})
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
|
|||||||
@@ -104,6 +104,7 @@ class PretrainedConfig(object):
|
|||||||
self.output_attentions = kwargs.pop('output_attentions', False)
|
self.output_attentions = kwargs.pop('output_attentions', False)
|
||||||
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
||||||
self.torchscript = kwargs.pop('torchscript', False)
|
self.torchscript = kwargs.pop('torchscript', False)
|
||||||
|
self.pruned_heads = kwargs.pop('pruned_heads', {})
|
||||||
|
|
||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory):
|
||||||
""" Save a configuration object to the directory `save_directory`, so that it
|
""" Save a configuration object to the directory `save_directory`, so that it
|
||||||
@@ -363,6 +364,15 @@ class PreTrainedModel(nn.Module):
|
|||||||
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
|
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
|
||||||
"""
|
"""
|
||||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||||
|
|
||||||
|
for layer, heads in heads_to_prune.items():
|
||||||
|
if str(layer) not in self.config.pruned_heads:
|
||||||
|
self.config.pruned_heads[str(layer)] = heads
|
||||||
|
else:
|
||||||
|
for head in heads:
|
||||||
|
if head not in self.config.pruned_heads[str(layer)]:
|
||||||
|
self.config.pruned_heads[str(layer)].append(head)
|
||||||
|
|
||||||
base_model._prune_heads(heads_to_prune)
|
base_model._prune_heads(heads_to_prune)
|
||||||
|
|
||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory):
|
||||||
|
|||||||
@@ -219,6 +219,7 @@ class CommonTestCases:
|
|||||||
del inputs_dict["head_mask"]
|
del inputs_dict["head_mask"]
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.output_attentions = True
|
config.output_attentions = True
|
||||||
config.output_hidden_states = False
|
config.output_hidden_states = False
|
||||||
model = model_class(config=config)
|
model = model_class(config=config)
|
||||||
@@ -237,6 +238,61 @@ class CommonTestCases:
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
|
||||||
|
def test_head_pruning_save_load_from_pretrained(self):
|
||||||
|
if not self.test_pruning:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.output_attentions = True
|
||||||
|
config.output_hidden_states = False
|
||||||
|
model = model_class(config=config)
|
||||||
|
model.eval()
|
||||||
|
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
|
||||||
|
-1: [0]}
|
||||||
|
model.prune_heads(heads_to_prune)
|
||||||
|
directory = "pruned_model"
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
os.makedirs(directory)
|
||||||
|
model.save_pretrained(directory)
|
||||||
|
model = model_class.from_pretrained(directory)
|
||||||
|
|
||||||
|
outputs = model(**inputs_dict)
|
||||||
|
attentions = outputs[-1]
|
||||||
|
self.assertEqual(
|
||||||
|
attentions[0].shape[-3], 1)
|
||||||
|
self.assertEqual(
|
||||||
|
attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||||
|
self.assertEqual(
|
||||||
|
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
|
||||||
|
shutil.rmtree(directory)
|
||||||
|
|
||||||
|
def test_head_pruning_save_load_from_config_init(self):
|
||||||
|
print(self.test_pruning)
|
||||||
|
if not self.test_pruning:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.output_attentions = True
|
||||||
|
config.output_hidden_states = False
|
||||||
|
|
||||||
|
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
|
||||||
|
-1: [0]}
|
||||||
|
config.pruned_heads = heads_to_prune
|
||||||
|
|
||||||
|
model = model_class(config=config)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
outputs = model(**inputs_dict)
|
||||||
|
attentions = outputs[-1]
|
||||||
|
self.assertEqual(
|
||||||
|
attentions[0].shape[-3], 1)
|
||||||
|
self.assertEqual(
|
||||||
|
attentions[1].shape[-3], self.model_tester.num_attention_heads)
|
||||||
|
self.assertEqual(
|
||||||
|
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user