pruning in bertology
@@ -724,7 +724,7 @@ We detail them here. This model takes as *inputs*:
|
|||||||
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
|
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
|
||||||
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if some input sequence lengths are smaller than the max input sequence length of the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
|
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if some input sequence lengths are smaller than the max input sequence length of the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
|
||||||
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
|
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
|
||||||
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked.
|
||||||
|
|
||||||
This model *outputs* a tuple composed of:
|
This model *outputs* a tuple composed of:
|
||||||
|
|
||||||
@@ -859,7 +859,7 @@ We detail them here. This model takes as *inputs*:
|
|||||||
- `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
- `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||||
You can use it to add a third type of embedding to each input token in the sequence
|
You can use it to add a third type of embedding to each input token in the sequence
|
||||||
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
|
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
|
||||||
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked.
|
||||||
|
|
||||||
This model *outputs*:
|
This model *outputs*:
|
||||||
- `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
|
- `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
|
||||||
@@ -960,7 +960,7 @@ We detail them here. This model takes as *inputs*:
|
|||||||
You can use it to add a third type of embedding to each input token in the sequence
|
You can use it to add a third type of embedding to each input token in the sequence
|
||||||
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
|
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
|
||||||
- `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states (key and values in the attention blocks) to speed up sequential decoding (this is the `presents` output of the model, cf. below).
|
- `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states (key and values in the attention blocks) to speed up sequential decoding (this is the `presents` output of the model, cf. below).
|
||||||
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked.
|
||||||
|
|
||||||
This model *outputs*:
|
This model *outputs*:
|
||||||
- `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
|
- `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
|
||||||
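As a minimal sketch of the corrected `head_mask` convention documented above (0.0 masks a head, 1.0 keeps it), the snippet below uses plain Python lists instead of `torch` tensors. The helper names `expand_head_mask` and `apply_head_mask` are hypothetical and not part of the library; they only illustrate how a `[num_heads]` mask can be shared across layers and how each head's values are multiplied by its mask entry.

```python
def expand_head_mask(head_mask, num_layers):
    """Return a [num_layers][num_heads] mask from a 1-D or 2-D mask (hypothetical helper)."""
    if head_mask and isinstance(head_mask[0], (int, float)):
        # 1-D mask: the same per-head values are reused for every layer
        return [list(head_mask) for _ in range(num_layers)]
    if len(head_mask) != num_layers:
        raise ValueError("2-D head_mask must have one row per layer")
    return [list(row) for row in head_mask]


def apply_head_mask(head_values, layer_mask):
    """Scale each head's values by its mask entry (1.0 keeps, 0.0 nullifies)."""
    return [[m * x for x in head] for head, m in zip(head_values, layer_mask)]


# Two layers, three heads: keep heads 0 and 2, mask head 1 in every layer.
mask = expand_head_mask([1.0, 0.0, 1.0], num_layers=2)
values = [[0.5, 0.5], [0.2, 0.8], [0.9, 0.1]]  # toy per-head attention rows for one layer
masked = apply_head_mask(values, mask[0])
print(masked)
```

With this convention a mask of all ones leaves the model unchanged, which is why the tests in this commit now start from `torch.ones` rather than `torch.zeros`.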
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import timedelta, datetime
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -35,38 +36,56 @@ def print_2d_tensor(tensor):
|
|||||||
for row in range(len(tensor)):
|
for row in range(len(tensor)):
|
||||||
print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t")
|
print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t")
|
||||||
|
|
||||||
def compute_heads_importance(args, model, eval_dataloader):
|
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
|
||||||
""" Example on how to use model outputs to compute:
|
""" Example on how to use model outputs to compute:
|
||||||
- head attention entropy (activated by setting output_attentions=True when we created the model)
|
- head attention entropy (activated by setting output_attentions=True when we created the model)
|
||||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||||
(activated by setting keep_multihead_output=True when we created the model)
|
(activated by setting keep_multihead_output=True when we created the model)
|
||||||
"""
|
"""
|
||||||
|
# Prepare our tensors
|
||||||
|
n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads
|
||||||
|
head_importance = torch.zeros(n_layers, n_heads).to(args.device)
|
||||||
|
attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
|
||||||
|
preds = None
|
||||||
|
labels = None
|
||||||
|
tot_tokens = 0.0
|
||||||
|
|
||||||
for step, batch in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
for step, batch in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
|
|
||||||
# Do a forward pass
|
# Do a forward pass (not in torch.no_grad() since we need gradients for importance score - see below)
|
||||||
all_attentions, logits = model(input_ids, segment_ids, input_mask)
|
all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask)
|
||||||
|
|
||||||
# Update head attention entropy
|
if compute_entropy:
|
||||||
for layer, attn in enumerate(all_attentions):
|
# Update head attention entropy
|
||||||
masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
|
for layer, attn in enumerate(all_attentions):
|
||||||
attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()
|
masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
|
||||||
|
attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()
|
||||||
|
|
||||||
# Update head importance scores with regards to our loss
|
if compute_importance:
|
||||||
# First backpropagate to populate the gradients
|
# Update head importance scores with regards to our loss
|
||||||
if output_mode == "classification":
|
# First, backpropagate to populate the gradients
|
||||||
loss_fct = CrossEntropyLoss()
|
if args.output_mode == "classification":
|
||||||
loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
loss_fct = CrossEntropyLoss()
|
||||||
elif output_mode == "regression":
|
loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))
|
||||||
loss_fct = MSELoss()
|
elif args.output_mode == "regression":
|
||||||
loss = loss_fct(logits.view(-1), label_ids.view(-1))
|
loss_fct = MSELoss()
|
||||||
loss.backward()
|
loss = loss_fct(logits.view(-1), label_ids.view(-1))
|
||||||
# Second compute importance scores according to http://arxiv.org/abs/1905.10650
|
loss.backward()
|
||||||
multihead_outputs = model.bert.get_multihead_outputs()
|
# Second, compute importance scores according to http://arxiv.org/abs/1905.10650
|
||||||
for layer, mh_layer_output in enumerate(multihead_outputs):
|
multihead_outputs = model.bert.get_multihead_outputs()
|
||||||
dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output])
|
for layer, mh_layer_output in enumerate(multihead_outputs):
|
||||||
head_importance[layer] += dot.abs().sum(-1).sum(0).detach()
|
dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output])
|
||||||
|
head_importance[layer] += dot.abs().sum(-1).sum(0).detach()
|
||||||
|
|
||||||
|
# Also store our logits/labels if we want to compute metrics afterwards
|
||||||
|
if preds is None:
|
||||||
|
preds = logits.detach().cpu().numpy()
|
||||||
|
labels = label_ids.detach().cpu().numpy()
|
||||||
|
else:
|
||||||
|
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||||
|
labels = np.append(labels, label_ids.detach().cpu().numpy(), axis=0)
|
||||||
|
|
||||||
tot_tokens += input_mask.float().detach().sum().data
|
tot_tokens += input_mask.float().detach().sum().data
|
||||||
|
|
||||||
@@ -76,7 +95,7 @@ def compute_heads_importance(args, model, eval_dataloader):
|
|||||||
if args.normalize_importance:
|
if args.normalize_importance:
|
||||||
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
||||||
|
|
||||||
return attn_entropy, head_importance
|
return attn_entropy, head_importance, preds, labels
|
||||||
|
|
||||||
def run_model():
|
def run_model():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -89,8 +108,11 @@ def run_model():
|
|||||||
|
|
||||||
parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1")
|
parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1")
|
||||||
|
|
||||||
parser.add_argument("--try_pruning", action='store_true', help="Whether to try to prune head until a threshold of accuracy.")
|
parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask heads until a metric threshold is reached.")
|
||||||
parser.add_argument("--pruning_threshold", default=0.9, type=float, help="Pruning threshold of accuracy.")
|
parser.add_argument("--masking_threshold", default=0.9, type=float, help="Masking threshold in terms of metrics "
|
||||||
|
"(stop masking when metric < threshold * original metric value).")
|
||||||
|
parser.add_argument("--masking_amount", default=0.1, type=float, help="Fraction of heads to mask at each masking step.")
|
||||||
|
parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
|
||||||
|
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n"
|
parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||||
"Sequences longer than this will be truncated, and sequences shorter \n"
|
"Sequences longer than this will be truncated, and sequences shorter \n"
|
||||||
@@ -125,9 +147,9 @@ def run_model():
|
|||||||
# Prepare GLUE task
|
# Prepare GLUE task
|
||||||
task_name = args.task_name.lower()
|
task_name = args.task_name.lower()
|
||||||
processor = processors[task_name]()
|
processor = processors[task_name]()
|
||||||
output_mode = output_modes[task_name]
|
|
||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
num_labels = len(label_list)
|
args.output_mode = output_modes[task_name]
|
||||||
|
args.num_labels = len(label_list)
|
||||||
|
|
||||||
# Prepare output directory
|
# Prepare output directory
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
|
||||||
@@ -145,7 +167,7 @@ def run_model():
|
|||||||
# keep_multihead_output => will store gradient of attention head outputs for head importance computation
|
# keep_multihead_output => will store gradient of attention head outputs for head importance computation
|
||||||
# see: http://arxiv.org/abs/1905.10650
|
# see: http://arxiv.org/abs/1905.10650
|
||||||
model = BertForSequenceClassification.from_pretrained(args.model_name_or_path,
|
model = BertForSequenceClassification.from_pretrained(args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=args.num_labels,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
keep_multihead_output=True)
|
keep_multihead_output=True)
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
@@ -162,7 +184,7 @@ def run_model():
|
|||||||
try:
|
try:
|
||||||
eval_features = torch.load(cached_eval_features_file)
|
eval_features = torch.load(cached_eval_features_file)
|
||||||
except:
|
except:
|
||||||
eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, args.output_mode)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving eval features to cache file %s", cached_eval_features_file)
|
logger.info("Saving eval features to cache file %s", cached_eval_features_file)
|
||||||
torch.save(eval_features, cached_eval_features_file)
|
torch.save(eval_features, cached_eval_features_file)
|
||||||
@@ -170,7 +192,7 @@ def run_model():
|
|||||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if output_mode == "classification" else torch.float)
|
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if args.output_mode == "classification" else torch.float)
|
||||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||||
|
|
||||||
if args.data_subset > 0:
|
if args.data_subset > 0:
|
||||||
@@ -183,16 +205,8 @@ def run_model():
|
|||||||
print(args)
|
print(args)
|
||||||
torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))
|
||||||
|
|
||||||
# To showcase some BERTology methods, we will compute:
|
|
||||||
# - the average entropy of each head over the dev set
|
|
||||||
# - the importance score of each head over the dev set as explained in http://arxiv.org/abs/1905.10650
|
|
||||||
n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads
|
|
||||||
head_importance = torch.zeros(n_layers, n_heads).to(args.device)
|
|
||||||
attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
|
|
||||||
tot_tokens = 0.0
|
|
||||||
|
|
||||||
# Compute head entropy and importance score
|
# Compute head entropy and importance score
|
||||||
attn_entropy, head_importance = compute_heads_importance(args, model, eval_dataloader)
|
attn_entropy, head_importance, _, _ = compute_heads_importance(args, model, eval_dataloader)
|
||||||
|
|
||||||
# Print/save matrices
|
# Print/save matrices
|
||||||
np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy)
|
np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy)
|
||||||
@@ -203,14 +217,67 @@ def run_model():
|
|||||||
logger.info("Head importance scores")
|
logger.info("Head importance scores")
|
||||||
print_2d_tensor(head_importance)
|
print_2d_tensor(head_importance)
|
||||||
logger.info("Head ranked by importance scores")
|
logger.info("Head ranked by importance scores")
|
||||||
head_ranks = torch.zeros(n_layers * n_heads, dtype=torch.long, device=args.device)
|
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
||||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel())
|
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel())
|
||||||
print_2d_tensor(head_ranks.view_as(head_importance))
|
head_ranks = head_ranks.view_as(head_importance)
|
||||||
|
print_2d_tensor(head_ranks)
|
||||||
|
|
||||||
# Do pruning if we want to
|
# Do masking if we want to
|
||||||
if args.try_pruning and args.pruning_threshold > 0.0 and args.pruning_threshold < 1.0:
|
if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
|
||||||
|
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
||||||
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
|
original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
|
logger.info("Pruning: original score: %f", original_score)
|
||||||
|
|
||||||
|
new_head_mask = torch.ones_like(head_importance)
|
||||||
|
num_to_mask = int(new_head_mask.numel() * args.masking_amount)
|
||||||
|
|
||||||
|
current_score = original_score
|
||||||
|
while current_score >= original_score * args.masking_threshold:
|
||||||
|
head_mask = new_head_mask
|
||||||
|
# heads from most important to least
|
||||||
|
heads_to_mask = head_importance.view(-1).sort(descending=True)[1]
|
||||||
|
# keep only not-masked heads
|
||||||
|
heads_to_mask = heads_to_mask[head_mask.view(-1).nonzero()][:, 0]
|
||||||
|
|
||||||
|
if len(heads_to_mask) <= num_to_mask:
|
||||||
|
break
|
||||||
|
|
||||||
|
# mask heads
|
||||||
|
heads_to_mask = heads_to_mask[-num_to_mask:]
|
||||||
|
new_head_mask = head_mask.view(-1)
|
||||||
|
new_head_mask[heads_to_mask] = 0.0
|
||||||
|
new_head_mask = new_head_mask.view_as(head_importance)
|
||||||
|
print_2d_tensor(new_head_mask)
|
||||||
|
|
||||||
|
# Compute metric and head importance again
|
||||||
|
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
|
||||||
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
|
current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
|
logger.info("Masking: current score: %f, remaining heads %.1f percent", current_score, head_mask.sum()/head_mask.numel() * 100)
|
||||||
|
|
||||||
|
# Try pruning and test time speedup
|
||||||
|
# Pruning is like masking but we actually remove the masked weights
|
||||||
|
before_time = datetime.now()
|
||||||
|
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
|
||||||
|
compute_entropy=False, compute_importance=False, head_mask=head_mask)
|
||||||
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
|
score_masking = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
|
original_time = datetime.now() - before_time
|
||||||
|
|
||||||
|
heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask)))
|
||||||
|
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||||
|
model.bert.prune_heads(heads_to_prune)
|
||||||
|
|
||||||
|
before_time = datetime.now()
|
||||||
|
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
|
||||||
|
compute_entropy=False, compute_importance=False, head_mask=None)
|
||||||
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
|
score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name]
|
||||||
|
new_time = datetime.now() - before_time
|
||||||
|
|
||||||
|
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
||||||
|
logger.info("Pruning: speed ratio (original timing / new timing): %f percent", original_time/new_time * 100)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_model()
|
run_model()
|
||||||
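The selection step of the masking loop above can be sketched in plain Python. This is an illustration under stated assumptions, not the script's code: `select_heads_to_mask` and `mask_heads` are hypothetical names, nested lists stand in for tensors, candidates are filtered by their own mask value, and the mask is copied rather than modified in place, which is one reading of the loop's intent.

```python
def select_heads_to_mask(importance, head_mask, num_to_mask):
    """Return flat indices of the `num_to_mask` least important unmasked heads.

    Returns an empty list when too few unmasked heads remain, mirroring the
    `break` condition of the masking loop.
    """
    flat_importance = [score for row in importance for score in row]
    flat_mask = [m for row in head_mask for m in row]
    # Indices of unmasked heads, sorted from most to least important
    candidates = sorted(
        (i for i, m in enumerate(flat_mask) if m != 0.0),
        key=lambda i: flat_importance[i],
        reverse=True,
    )
    if len(candidates) <= num_to_mask:
        return []
    return candidates[-num_to_mask:]


def mask_heads(head_mask, flat_indices):
    """Return a copy of `head_mask` with the given flat indices set to 0.0."""
    n_heads = len(head_mask[0])
    new_mask = [list(row) for row in head_mask]
    for idx in flat_indices:
        new_mask[idx // n_heads][idx % n_heads] = 0.0
    return new_mask


importance = [[0.9, 0.1, 0.5], [0.3, 0.7, 0.2]]
head_mask = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
to_mask = select_heads_to_mask(importance, head_mask, num_to_mask=2)
new_mask = mask_heads(head_mask, to_mask)
print(to_mask, new_mask)
```

In the script, each such step is followed by a new evaluation pass, and the loop stops once the metric falls below `masking_threshold` times the original score; the last mask that stayed above the threshold is the one used for pruning.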
|
|||||||
@@ -308,7 +308,7 @@ def main():
|
|||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
|
|
||||||
# define a new function to compute loss values for both output_modes
|
# define a new function to compute loss values for both output_modes
|
||||||
logits = model(input_ids, segment_ids, input_mask)
|
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
|
||||||
|
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
@@ -422,7 +422,7 @@ def main():
|
|||||||
label_ids = label_ids.to(device)
|
label_ids = label_ids.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(input_ids, segment_ids, input_mask)
|
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
|
||||||
|
|
||||||
# create eval loss and other metric required by the task
|
# create eval loss and other metric required by the task
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
@@ -503,7 +503,7 @@ def main():
|
|||||||
label_ids = label_ids.to(device)
|
label_ids = label_ids.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(input_ids, segment_ids, input_mask, labels=None)
|
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
|
||||||
|
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||||
|
|||||||
@@ -408,6 +408,8 @@ class BertAttention(nn.Module):
|
|||||||
self.output = BertSelfOutput(config)
|
self.output = BertSelfOutput(config)
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
||||||
for head in heads:
|
for head in heads:
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
@@ -858,9 +860,10 @@ class BertModel(BertPreTrainedModel):
|
|||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we mask the head
|
# 1.0 in head_mask indicates we keep the head
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
# head_mask has shape num_hidden_layers x batch x n_heads x N x N
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
if head_mask is not None:
|
if head_mask is not None:
|
||||||
if head_mask.dim() == 1:
|
if head_mask.dim() == 1:
|
||||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||||
@@ -868,7 +871,6 @@ class BertModel(BertPreTrainedModel):
|
|||||||
elif head_mask.dim() == 2:
|
elif head_mask.dim() == 2:
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
|
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
|
||||||
head_mask = (1.0 - head_mask)
|
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.num_hidden_layers
|
head_mask = [None] * self.config.num_hidden_layers
|
||||||
|
|
||||||
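Since `prune_heads` iterates over head indices for one attention layer, a head mask has to be converted into per-layer index lists before pruning. The sketch below shows that conversion on nested Python lists; `heads_to_prune_from_mask` is a hypothetical helper, and it assumes flat integer lists per layer (calling `.nonzero().tolist()` on a 1-D tensor yields nested single-element lists instead).

```python
def heads_to_prune_from_mask(head_mask):
    """Map each layer index to the list of head indices whose mask value is 0."""
    return {
        layer: [head for head, keep in enumerate(row) if keep == 0]
        for layer, row in enumerate(head_mask)
    }


head_mask = [[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
heads_to_prune = heads_to_prune_from_mask(head_mask)
# Sanity check in the spirit of the script: pruned heads must equal masked heads
assert sum(len(h) for h in heads_to_prune.values()) == sum(row.count(0.0) for row in head_mask)
print(heads_to_prune)
```

Layers with no masked head map to an empty list, which the early `return` added to `prune_heads` in this commit handles without touching the layer.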
|
|||||||
@@ -264,6 +264,8 @@ class Attention(nn.Module):
|
|||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
||||||
for head in heads:
|
for head in heads:
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
@@ -714,7 +716,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we mask the head
|
# 1.0 in head_mask indicates we keep the head
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||||
if head_mask is not None:
|
if head_mask is not None:
|
||||||
@@ -724,7 +726,6 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
elif head_mask.dim() == 2:
|
elif head_mask.dim() == 2:
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
|
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
|
||||||
head_mask = (1.0 - head_mask)
|
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.n_layer
|
head_mask = [None] * self.config.n_layer
|
||||||
|
|
||||||
|
|||||||
@@ -274,6 +274,8 @@ class Attention(nn.Module):
|
|||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
||||||
for head in heads:
|
for head in heads:
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
@@ -710,7 +712,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we mask the head
|
# 1.0 in head_mask indicates we keep the head
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||||
if head_mask is not None:
|
if head_mask is not None:
|
||||||
@@ -720,7 +722,6 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
elif head_mask.dim() == 2:
|
elif head_mask.dim() == 2:
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
|
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
|
||||||
head_mask = (1.0 - head_mask)
|
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.n_layer
|
head_mask = [None] * self.config.n_layer
|
||||||
|
|
||||||
|
|||||||
@@ -215,9 +215,9 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
|
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
|
||||||
model = model_class(config=config, keep_multihead_output=True)
|
model = model_class(config=config, keep_multihead_output=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
|
head_mask = torch.ones(self.n_layer, self.n_head).to(input_ids.device)
|
||||||
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
|
head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
|
||||||
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
|
head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
|
||||||
if isinstance(model, GPT2DoubleHeadsModel):
|
if isinstance(model, GPT2DoubleHeadsModel):
|
||||||
output = model(input_ids, mc_token_ids, head_mask=head_mask)
|
output = model(input_ids, mc_token_ids, head_mask=head_mask)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -188,9 +188,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
|
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
|
||||||
model = model_class(config=config, keep_multihead_output=True)
|
model = model_class(config=config, keep_multihead_output=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
|
head_mask = torch.ones(self.n_layer, self.n_head).to(input_ids.device)
|
||||||
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
|
head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
|
||||||
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
|
head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
|
||||||
if isinstance(model, OpenAIGPTDoubleHeadsModel):
|
if isinstance(model, OpenAIGPTDoubleHeadsModel):
|
||||||
output = model(input_ids, mc_token_ids, head_mask=head_mask)
|
output = model(input_ids, mc_token_ids, head_mask=head_mask)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -305,9 +305,9 @@ class BertModelTest(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
model = model_class(config=config, keep_multihead_output=True)
|
model = model_class(config=config, keep_multihead_output=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
head_mask = torch.zeros(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
|
head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
|
||||||
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
|
head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
|
||||||
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
|
head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
|
||||||
output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
|
output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
|
||||||
|
|
||||||
if isinstance(model, BertModel):
|
if isinstance(model, BertModel):
|
||||||
|
|||||||