[cleanup] consolidate some prune_heads logic (#4799)
This commit is contained in:
@@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
|||||||
from .configuration_albert import AlbertConfig
|
from .configuration_albert import AlbertConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
|
from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -199,14 +199,9 @@ class AlbertAttention(BertSelfAttention):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.num_attention_heads, self.attention_head_size)
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
|
heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
|
||||||
for head in heads:
|
)
|
||||||
# Compute how many pruned heads are before the head and move the index accordingly
|
|
||||||
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
|
||||||
mask[head] = 0
|
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
|
||||||
index = torch.arange(len(mask))[mask].long()
|
|
||||||
|
|
||||||
# Prune linear layers
|
# Prune linear layers
|
||||||
self.query = prune_linear_layer(self.query, index)
|
self.query = prune_linear_layer(self.query, index)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
|||||||
from .activations import gelu, gelu_new, swish
|
from .activations import gelu, gelu_new, swish
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -284,14 +284,9 @@ class BertAttention(nn.Module):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
|
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||||
for head in heads:
|
)
|
||||||
# Compute how many pruned heads are before the head and move the index accordingly
|
|
||||||
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
|
||||||
mask[head] = 0
|
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
|
||||||
index = torch.arange(len(mask))[mask].long()
|
|
||||||
|
|
||||||
# Prune linear layers
|
# Prune linear layers
|
||||||
self.self.query = prune_linear_layer(self.self.query, index)
|
self.self.query = prune_linear_layer(self.self.query, index)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from .activations import gelu
|
from .activations import gelu
|
||||||
from .configuration_distilbert import DistilBertConfig
|
from .configuration_distilbert import DistilBertConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -120,13 +120,7 @@ class MultiHeadSelfAttention(nn.Module):
|
|||||||
attention_head_size = self.dim // self.n_heads
|
attention_head_size = self.dim // self.n_heads
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_heads, attention_head_size)
|
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
|
||||||
heads = set(heads) - self.pruned_heads
|
|
||||||
for head in heads:
|
|
||||||
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
|
|
||||||
mask[head] = 0
|
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
|
||||||
index = torch.arange(len(mask))[mask].long()
|
|
||||||
# Prune linear layers
|
# Prune linear layers
|
||||||
self.q_lin = prune_linear_layer(self.q_lin, index)
|
self.q_lin = prune_linear_layer(self.q_lin, index)
|
||||||
self.k_lin = prune_linear_layer(self.k_lin, index)
|
self.k_lin = prune_linear_layer(self.k_lin, index)
|
||||||
|
|||||||
@@ -27,7 +27,13 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from .activations import ACT2FN
|
from .activations import ACT2FN
|
||||||
from .configuration_gpt2 import GPT2Config
|
from .configuration_gpt2 import GPT2Config
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
|
from .modeling_utils import (
|
||||||
|
Conv1D,
|
||||||
|
PreTrainedModel,
|
||||||
|
SequenceSummary,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_conv1d_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -122,14 +128,9 @@ class Attention(nn.Module):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
|
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
|
||||||
for head in heads:
|
)
|
||||||
# Compute how many pruned heads are before the head and move the index accordingly
|
|
||||||
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
|
||||||
mask[head] = 0
|
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
|
||||||
index = torch.arange(len(mask))[mask].long()
|
|
||||||
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
||||||
|
|
||||||
# Prune conv1d layers
|
# Prune conv1d layers
|
||||||
|
|||||||
@@ -29,7 +29,13 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from .activations import gelu_new, swish
|
from .activations import gelu_new, swish
|
||||||
from .configuration_openai import OpenAIGPTConfig
|
from .configuration_openai import OpenAIGPTConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
|
from .modeling_utils import (
|
||||||
|
Conv1D,
|
||||||
|
PreTrainedModel,
|
||||||
|
SequenceSummary,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_conv1d_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -142,13 +148,9 @@ class Attention(nn.Module):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
heads = set(heads) - self.pruned_heads
|
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
|
||||||
for head in heads:
|
)
|
||||||
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
|
|
||||||
mask[head] = 0
|
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
|
||||||
index = torch.arange(len(mask))[mask].long()
|
|
||||||
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
||||||
# Prune conv1d layers
|
# Prune conv1d layers
|
||||||
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
|
|||||||
|
|
||||||
from .configuration_t5 import T5Config
|
from .configuration_t5 import T5Config
|
||||||
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -216,13 +216,7 @@ class T5Attention(nn.Module):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_heads, self.d_kv)
|
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads)
|
||||||
heads = set(heads) - self.pruned_heads
|
|
||||||
for head in heads:
|
|
||||||
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
|
|
||||||
mask[head] = 0
|
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
|
||||||
index = torch.arange(len(mask))[mask].long()
|
|
||||||
# Prune linear layers
|
# Prune linear layers
|
||||||
self.q = prune_linear_layer(self.q, index)
|
self.q = prune_linear_layer(self.q, index)
|
||||||
self.k = prune_linear_layer(self.k, index)
|
self.k = prune_linear_layer(self.k, index)
|
||||||
|
|||||||
@@ -55,6 +55,20 @@ except ImportError:
|
|||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
def find_pruneable_heads_and_indices(
|
||||||
|
heads: List, n_heads: int, head_size: int, already_pruned_heads: set
|
||||||
|
) -> Tuple[set, "torch.LongTensor"]:
|
||||||
|
mask = torch.ones(n_heads, head_size)
|
||||||
|
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
|
||||||
|
for head in heads:
|
||||||
|
# Compute how many pruned heads are before the head and move the index accordingly
|
||||||
|
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
|
||||||
|
mask[head] = 0
|
||||||
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
|
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
|
||||||
|
return heads, index
|
||||||
|
|
||||||
|
|
||||||
class ModuleUtilsMixin:
|
class ModuleUtilsMixin:
|
||||||
"""
|
"""
|
||||||
A few utilities for torch.nn.Modules, to be used as a mixin.
|
A few utilities for torch.nn.Modules, to be used as a mixin.
|
||||||
|
|||||||
@@ -29,7 +29,13 @@ from torch.nn import functional as F
|
|||||||
from .activations import gelu
|
from .activations import gelu
|
||||||
from .configuration_xlm import XLMConfig
|
from .configuration_xlm import XLMConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_linear_layer
|
from .modeling_utils import (
|
||||||
|
PreTrainedModel,
|
||||||
|
SequenceSummary,
|
||||||
|
SQuADHead,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_linear_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -105,13 +111,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
attention_head_size = self.dim // self.n_heads
|
attention_head_size = self.dim // self.n_heads
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_heads, attention_head_size)
|
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
|
||||||
heads = set(heads) - self.pruned_heads
|
|
||||||
for head in heads:
|
|
||||||
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
|
|
||||||
mask[head] = 0
|
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
|
||||||
index = torch.arange(len(mask))[mask].long()
|
|
||||||
# Prune linear layers
|
# Prune linear layers
|
||||||
self.q_lin = prune_linear_layer(self.q_lin, index)
|
self.q_lin = prune_linear_layer(self.q_lin, index)
|
||||||
self.k_lin = prune_linear_layer(self.k_lin, index)
|
self.k_lin = prune_linear_layer(self.k_lin, index)
|
||||||
|
|||||||
Reference in New Issue
Block a user