From c85b5db61a8825edda59a0e9f12bc1be08c63cdc Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 21 Aug 2019 21:37:30 -0400 Subject: [PATCH] Conditional append/init + fixed warning --- pytorch_transformers/modeling_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 5a89badba6..c69cba49e3 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -379,11 +379,15 @@ class PreTrainedModel(nn.Module): for head in heads: if head not in self.config.pruned_heads[int(layer)]: self.config.pruned_heads[int(layer)].append(head) - to_be_pruned[int(layer)].append(head) + + if int(layer) in to_be_pruned: + to_be_pruned[int(layer)].append(head) + else: + to_be_pruned[int(layer)] = [head] else: - logger.warning("Tried to remove head " + head + - " of layer " + layer + - " but it was already removed. The current removed heads are " + heads_to_prune) + logger.warning("Tried to remove head " + str(head) + + " of layer " + str(layer) + + " but it was already removed. The current removed heads are " + str(heads_to_prune)) base_model._prune_heads(to_be_pruned)