From a139d1a1602ee72ca98d5e0412efbd68f746d2c8 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 8 Jun 2020 17:08:04 -0400 Subject: [PATCH] [cleanup] consolidate some prune_heads logic (#4799) --- src/transformers/modeling_albert.py | 13 ++++--------- src/transformers/modeling_bert.py | 13 ++++--------- src/transformers/modeling_distilbert.py | 10 ++-------- src/transformers/modeling_gpt2.py | 19 ++++++++++--------- src/transformers/modeling_openai.py | 18 ++++++++++-------- src/transformers/modeling_t5.py | 10 ++-------- src/transformers/modeling_utils.py | 14 ++++++++++++++ src/transformers/modeling_xlm.py | 16 ++++++++-------- 8 files changed, 54 insertions(+), 59 deletions(-) diff --git a/src/transformers/modeling_albert.py b/src/transformers/modeling_albert.py index a6c54970f0..29d0b58226 100644 --- a/src/transformers/modeling_albert.py +++ b/src/transformers/modeling_albert.py @@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss, MSELoss from .configuration_albert import AlbertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer -from .modeling_utils import PreTrainedModel +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices logger = logging.getLogger(__name__) @@ -199,14 +199,9 @@ class AlbertAttention(BertSelfAttention): def prune_heads(self, heads): if len(heads) == 0: return - mask = torch.ones(self.num_attention_heads, self.attention_head_size) - heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads - for head in heads: - # Compute how many pruned heads are before the head and move the index accordingly - head = head - sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() + heads, index = find_pruneable_heads_and_indices( + heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads + ) # Prune linear layers self.query = prune_linear_layer(self.query, index) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 82e8df0abb..733b1e8898 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss, MSELoss from .activations import gelu, gelu_new, swish from .configuration_bert import BertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel, prune_linear_layer +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer logger = logging.getLogger(__name__) @@ -284,14 +284,9 @@ class BertAttention(nn.Module): def prune_heads(self, heads): if len(heads) == 0: return - mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) - heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads - for head in heads: - # Compute how many pruned heads are before the head and move the index accordingly - head = head - sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) diff --git a/src/transformers/modeling_distilbert.py b/src/transformers/modeling_distilbert.py index 281553616e..1105260afb 100644 --- a/src/transformers/modeling_distilbert.py +++ b/src/transformers/modeling_distilbert.py @@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss from .activations import gelu from .configuration_distilbert import DistilBertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel, prune_linear_layer +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer logger = logging.getLogger(__name__) @@ -120,13 +120,7 @@ class MultiHeadSelfAttention(nn.Module): attention_head_size = self.dim // self.n_heads if len(heads) == 0: return - mask = torch.ones(self.n_heads, attention_head_size) - heads = set(heads) - self.pruned_heads - for head in heads: - head -= sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() + heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads) # Prune linear layers self.q_lin = prune_linear_layer(self.q_lin, index) self.k_lin = prune_linear_layer(self.k_lin, index) diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index cc9c89cd39..a8184f4946 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -27,7 +27,13 @@ from torch.nn import CrossEntropyLoss from .activations import ACT2FN from .configuration_gpt2 import GPT2Config from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer +from .modeling_utils import ( + Conv1D, + PreTrainedModel, + SequenceSummary, + find_pruneable_heads_and_indices, + prune_conv1d_layer, +) logger = logging.getLogger(__name__) @@ -122,14 +128,9 @@ class Attention(nn.Module): def prune_heads(self, heads): if len(heads) == 0: return - mask = torch.ones(self.n_head, self.split_size // self.n_head) - heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads - for head in heads: - # Compute how many pruned heads are before the head and move the index accordingly - head = head - sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() + heads, index = find_pruneable_heads_and_indices( + heads, self.n_head, self.split_size // self.n_head, self.pruned_heads + ) index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) # Prune conv1d layers diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py index a1c729ac69..ab27ad7c17 100644 --- a/src/transformers/modeling_openai.py +++ b/src/transformers/modeling_openai.py @@ -29,7 +29,13 @@ from torch.nn import CrossEntropyLoss from .activations import gelu_new, swish from .configuration_openai import OpenAIGPTConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer +from .modeling_utils import ( + Conv1D, + PreTrainedModel, + SequenceSummary, + find_pruneable_heads_and_indices, + prune_conv1d_layer, +) logger = logging.getLogger(__name__) @@ -142,13 +148,9 @@ class Attention(nn.Module): def prune_heads(self, heads): if len(heads) == 0: return - mask = torch.ones(self.n_head, self.split_size // self.n_head) - heads = set(heads) - self.pruned_heads - for head in heads: - head -= sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() + heads, index = find_pruneable_heads_and_indices( + heads, self.n_head, self.split_size // self.n_head, self.pruned_heads + ) index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) # Prune conv1d layers self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index ccc005571b..8f1a3ea49a 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss from .configuration_t5 import T5Config from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel, prune_linear_layer +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer logger = logging.getLogger(__name__) @@ -216,13 +216,7 @@ class T5Attention(nn.Module): def prune_heads(self, heads): if len(heads) == 0: return - mask = torch.ones(self.n_heads, self.d_kv) - heads = set(heads) - self.pruned_heads - for head in heads: - head -= sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() + heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads) # Prune linear layers self.q = prune_linear_layer(self.q, index) self.k = prune_linear_layer(self.k, index) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index edd11b174d..3f84bb4bad 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -55,6 +55,20 @@ except ImportError: return input +def find_pruneable_heads_and_indices( + heads: List, n_heads: int, head_size: int, already_pruned_heads: set +) -> Tuple[set, "torch.LongTensor"]: + mask = torch.ones(n_heads, head_size) + heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in already_pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index: torch.LongTensor = torch.arange(len(mask))[mask].long() + return heads, index + + class ModuleUtilsMixin: """ A few utilities for torch.nn.Modules, to be used as a mixin. diff --git a/src/transformers/modeling_xlm.py b/src/transformers/modeling_xlm.py index 44187db704..55e09cc056 100644 --- a/src/transformers/modeling_xlm.py +++ b/src/transformers/modeling_xlm.py @@ -29,7 +29,13 @@ from torch.nn import functional as F from .activations import gelu from .configuration_xlm import XLMConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_linear_layer +from .modeling_utils import ( + PreTrainedModel, + SequenceSummary, + SQuADHead, + find_pruneable_heads_and_indices, + prune_linear_layer, +) logger = logging.getLogger(__name__) @@ -105,13 +111,7 @@ class MultiHeadAttention(nn.Module): attention_head_size = self.dim // self.n_heads if len(heads) == 0: return - mask = torch.ones(self.n_heads, attention_head_size) - heads = set(heads) - self.pruned_heads - for head in heads: - head -= sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() + heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads) # Prune linear layers self.q_lin = prune_linear_layer(self.q_lin, index) self.k_lin = prune_linear_layer(self.k_lin, index)