Pruning saved to configuration first try

This commit is contained in:
Lysandre
2019-08-19 22:43:02 -04:00
committed by LysandreJik
parent d7a4c3252e
commit 42e00cf9e1
3 changed files with 72 additions and 0 deletions

View File

@@ -649,6 +649,12 @@ class BertModel(BertPreTrainedModel):
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
for layer, heads in pruned_heads:
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
self.prune_heads({int(layer): list(map(int, heads))})
self.apply(self.init_weights)
def _resize_token_embeddings(self, new_num_tokens):