From 34d706a0e18b42927749b78e41347dcfadc27458 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 19 Jun 2019 15:25:49 +0200 Subject: [PATCH] pruning in bertology --- README.md | 6 +- examples/bertology.py | 155 +++++++++++++++------ examples/run_classifier.py | 6 +- pytorch_pretrained_bert/modeling.py | 8 +- pytorch_pretrained_bert/modeling_gpt2.py | 5 +- pytorch_pretrained_bert/modeling_openai.py | 5 +- tests/modeling_gpt2_test.py | 6 +- tests/modeling_openai_test.py | 6 +- tests/modeling_test.py | 6 +- 9 files changed, 137 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index 287cc207e1..5d5148d529 100644 --- a/README.md +++ b/README.md @@ -724,7 +724,7 @@ We detail them here. This model takes as *inputs*: - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if some input sequence lengths are smaller than the max input sequence length of the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. -- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. +- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked. This model *outputs* a tuple composed of: @@ -859,7 +859,7 @@ We detail them here. This model takes as *inputs*: - `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids You can use it to add a third type of embedding to each input token in the sequence (the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block. -- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. +- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked. This model *outputs*: - `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids) @@ -960,7 +960,7 @@ We detail them here. This model takes as *inputs*: You can use it to add a third type of embedding to each input token in the sequence (the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block. - `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states (key and values in the attention blocks) to speed up sequential decoding (this is the `presents` output of the model, cf. below). -- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. +- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked. This model *outputs*: - `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids) diff --git a/examples/bertology.py b/examples/bertology.py index f7aa4b9970..888d95e5c6 100644 --- a/examples/bertology.py +++ b/examples/bertology.py @@ -2,6 +2,7 @@ import os import argparse import logging +from datetime import timedelta, datetime from tqdm import tqdm import numpy as np @@ -35,38 +36,56 @@ def print_2d_tensor(tensor): for row in range(len(tensor)): print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t") -def compute_heads_importance(args, model, eval_dataloader): +def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None): """ Example on how to use model outputs to compute: - head attention entropy (activated by setting output_attentions=True when we created the model - head importance scores according to http://arxiv.org/abs/1905.10650 (activated by setting keep_multihead_output=True when we created the model) """ + # Prepare our tensors + n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads + head_importance = torch.zeros(n_layers, n_heads).to(args.device) + attn_entropy = torch.zeros(n_layers, n_heads).to(args.device) + preds = None + labels = None + tot_tokens = 0.0 + for step, batch in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): batch = tuple(t.to(args.device) for t in batch) input_ids, input_mask, segment_ids, label_ids = batch - # Do a forward pass - all_attentions, logits = model(input_ids, segment_ids, input_mask) + # Do a forward pass (not in torch.no_grad() since we need gradients for importance score - see below) + all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask) - # Update head attention entropy - for layer, attn in enumerate(all_attentions): - masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1) - attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach() + if compute_entropy: + # Update head attention entropy + for layer, attn in enumerate(all_attentions): + masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1) + attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach() - # Update head importance scores with regards to our loss - # First backpropagate to populate the gradients - if output_mode == "classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) - elif output_mode == "regression": - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), label_ids.view(-1)) - loss.backward() - # Second compute importance scores according to http://arxiv.org/abs/1905.10650 - multihead_outputs = model.bert.get_multihead_outputs() - for layer, mh_layer_output in enumerate(multihead_outputs): - dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output]) - head_importance[layer] += dot.abs().sum(-1).sum(0).detach() + if compute_importance: + # Update head importance scores with regards to our loss + # First, backpropagate to populate the gradients + if args.output_mode == "classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1)) + elif args.output_mode == "regression": + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), label_ids.view(-1)) + loss.backward() + # Second, compute importance scores according to http://arxiv.org/abs/1905.10650 + multihead_outputs = model.bert.get_multihead_outputs() + for layer, mh_layer_output in enumerate(multihead_outputs): + dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output]) + head_importance[layer] += dot.abs().sum(-1).sum(0).detach() + + # Also store our logits/labels if we want to compute metrics afterwards + if preds is None: + preds = logits.detach().cpu().numpy() + labels = label_ids.detach().cpu().numpy() + else: + preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) + labels = np.append(labels, label_ids.detach().cpu().numpy(), axis=0) tot_tokens += input_mask.float().detach().sum().data @@ -76,7 +95,7 @@ def compute_heads_importance(args, model, eval_dataloader): if args.normalize_importance: head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min()) - return attn_entropy, head_importance + return attn_entropy, head_importance, preds, labels def run_model(): parser = argparse.ArgumentParser() @@ -89,8 +108,11 @@ def run_model(): parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1") - parser.add_argument("--try_pruning", action='store_true', help="Whether to try to prune head until a threshold of accuracy.") - parser.add_argument("--pruning_threshold", default=0.9, type=float, help="Pruning threshold of accuracy.") + parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.") + parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics" + "(stop masking when metric < threshold * original metric value).") + parser.add_argument("--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step.") + parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.") parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" @@ -125,9 +147,9 @@ def run_model(): # Prepare GLUE task task_name = args.task_name.lower() processor = processors[task_name]() - output_mode = output_modes[task_name] label_list = processor.get_labels() - num_labels = len(label_list) + args.output_mode = output_modes[task_name] + args.num_labels = len(label_list) # Prepare output directory if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir: @@ -145,7 +167,7 @@ def run_model(): # keep_multihead_output => will store gradient of attention head outputs for head importance computation # see: http://arxiv.org/abs/1905.10650 model = BertForSequenceClassification.from_pretrained(args.model_name_or_path, - num_labels=num_labels, + num_labels=args.num_labels, output_attentions=True, keep_multihead_output=True) if args.local_rank == 0: @@ -162,7 +184,7 @@ def run_model(): try: eval_features = torch.load(cached_eval_features_file) except: - eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) + eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, args.output_mode) if args.local_rank in [-1, 0]: logger.info("Saving eval features to cache file %s", cached_eval_features_file) torch.save(eval_features, cached_eval_features_file) @@ -170,7 +192,7 @@ def run_model(): all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) - all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if output_mode == "classification" else torch.float) + all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if args.output_mode == "classification" else torch.float) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) if args.data_subset > 0: @@ -183,16 +205,8 @@ def run_model(): print(args) torch.save(args, os.path.join(args.output_dir, 'run_args.bin')) - # To showcase some BERTology methods, we will compute: - # - the average entropy of each head over the dev set - # - the importance score of each head over the dev set as explained in http://arxiv.org/abs/1905.10650 - n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads - head_importance = torch.zeros(n_layers, n_heads).to(args.device) - attn_entropy = torch.zeros(n_layers, n_heads).to(args.device) - tot_tokens = 0.0 - # Compute head entropy and importance score - attn_entropy, head_importance = compute_heads_importance(args, model, eval_dataloader) + attn_entropy, head_importance, _, _ = compute_heads_importance(args, model, eval_dataloader) # Print/save matrices np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy) @@ -203,14 +217,67 @@ def run_model(): logger.info("Head importance scores") print_2d_tensor(head_importance) logger.info("Head ranked by importance scores") - head_ranks = torch.zeros(n_layers * n_heads, dtype=torch.long, device=args.device) + head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device) head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel()) - print_2d_tensor(head_ranks.view_as(head_importance)) + head_ranks = head_ranks.view_as(head_importance) + print_2d_tensor(head_ranks) - # Do pruning if we want to - if args.try_pruning and args.pruning_threshold > 0.0 and args.pruning_threshold < 1.0: - - + # Do masking if we want to + if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0: + _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False) + preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) + original_score = compute_metrics(task_name, preds, labels)[args.metric_name] + logger.info("Pruning: original score: %f", original_score) + + new_head_mask = torch.ones_like(head_importance) + num_to_mask = int(new_head_mask.numel() * args.masking_amount) + + current_score = original_score + while current_score >= original_score * args.masking_threshold: + head_mask = new_head_mask + # heads from most important to least + heads_to_mask = head_importance.view(-1).sort(descending=True)[1] + # keep only not-masked heads + heads_to_mask = heads_to_mask[head_mask.view(-1).nonzero()][:, 0] + + if len(heads_to_mask) <= num_to_mask: + break + + # mask heads + heads_to_mask = heads_to_mask[-num_to_mask:] + new_head_mask = head_mask.view(-1) + new_head_mask[heads_to_mask] = 0.0 + new_head_mask = new_head_mask.view_as(head_importance) + print_2d_tensor(new_head_mask) + + # Compute metric and head importance again + _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask) + preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) + current_score = compute_metrics(task_name, preds, labels)[args.metric_name] + logger.info("Masking: current score: %f, remaning heads %.1f percents", current_score, head_mask.sum()/head_mask.numel() * 100) + + # Try pruning and test time speedup + # Pruning is like masking but we actually remove the masked weights + before_time = datetime.now() + _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader, + compute_entropy=False, compute_importance=False, head_mask=head_mask) + preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) + score_masking = compute_metrics(task_name, preds, labels)[args.metric_name] + original_time = datetime.now() - before_time + + heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask))) + assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item() + model.bert.prune_heads(heads_to_prune) + + before_time = datetime.now() + _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader, + compute_entropy=False, compute_importance=False, head_mask=None) + preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds) + score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name] + new_time = datetime.now() - before_time + + logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning) + logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time/new_time * 100) if __name__ == '__main__': run_model() diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 0885d7b145..5a359ad262 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -308,7 +308,7 @@ def main(): input_ids, input_mask, segment_ids, label_ids = batch # define a new function to compute loss values for both output_modes - logits = model(input_ids, segment_ids, input_mask) + logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) if output_mode == "classification": loss_fct = CrossEntropyLoss() @@ -422,7 +422,7 @@ def main(): label_ids = label_ids.to(device) with torch.no_grad(): - logits = model(input_ids, segment_ids, input_mask) + logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) # create eval loss and other metric required by the task if output_mode == "classification": @@ -503,7 +503,7 @@ def main(): label_ids = label_ids.to(device) with torch.no_grad(): - logits = model(input_ids, segment_ids, input_mask, labels=None) + logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None) loss_fct = CrossEntropyLoss() tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index f5156d7d95..409956a141 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -408,6 +408,8 @@ class BertAttention(nn.Module): self.output = BertSelfOutput(config) def prune_heads(self, heads): + if len(heads) == 0: + return mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) for head in heads: mask[head] = 0 @@ -858,9 +860,10 @@ class BertModel(BertPreTrainedModel): extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 # Prepare head mask if needed - # 1.0 in head_mask indicate we mask the head + # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape num_hidden_layers x batch x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] if head_mask is not None: if head_mask.dim() == 1: head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) @@ -868,7 +871,6 @@ class BertModel(BertPreTrainedModel): elif head_mask.dim() == 2: head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility - head_mask = (1.0 - head_mask) else: head_mask = [None] * self.config.num_hidden_layers diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index dd195fc880..0a51af8118 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -264,6 +264,8 @@ class Attention(nn.Module): self.resid_dropout = nn.Dropout(config.resid_pdrop) def prune_heads(self, heads): + if len(heads) == 0: + return mask = torch.ones(self.n_head, self.split_size // self.n_head) for head in heads: mask[head] = 0 @@ -714,7 +716,7 @@ class GPT2Model(GPT2PreTrainedModel): position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # Prepare head mask if needed - # 1.0 in head_mask indicate we mask the head + # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # head_mask has shape n_layer x batch x n_heads x N x N if head_mask is not None: @@ -724,7 +726,6 @@ class GPT2Model(GPT2PreTrainedModel): elif head_mask.dim() == 2: head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility - head_mask = (1.0 - head_mask) else: head_mask = [None] * self.config.n_layer diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 91848f3c68..f02f016ef1 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -274,6 +274,8 @@ class Attention(nn.Module): self.resid_dropout = nn.Dropout(config.resid_pdrop) def prune_heads(self, heads): + if len(heads) == 0: + return mask = torch.ones(self.n_head, self.split_size // self.n_head) for head in heads: mask[head] = 0 @@ -710,7 +712,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # Prepare head mask if needed - # 1.0 in head_mask indicate we mask the head + # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # head_mask has shape n_layer x batch x n_heads x N x N if head_mask is not None: @@ -720,7 +722,6 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): elif head_mask.dim() == 2: head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility - head_mask = (1.0 - head_mask) else: head_mask = [None] * self.config.n_layer diff --git a/tests/modeling_gpt2_test.py b/tests/modeling_gpt2_test.py index 122b9c7913..589de22f5d 100644 --- a/tests/modeling_gpt2_test.py +++ b/tests/modeling_gpt2_test.py @@ -215,9 +215,9 @@ class GPT2ModelTest(unittest.TestCase): for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel): model = model_class(config=config, keep_multihead_output=True) model.eval() - head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device) - head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer - head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer + head_mask = torch.ones(self.n_layer, self.n_head).to(input_ids.device) + head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer + head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer if isinstance(model, GPT2DoubleHeadsModel): output = model(input_ids, mc_token_ids, head_mask=head_mask) else: diff --git a/tests/modeling_openai_test.py b/tests/modeling_openai_test.py index adb3671ef5..c8fc8f48fc 100644 --- a/tests/modeling_openai_test.py +++ b/tests/modeling_openai_test.py @@ -188,9 +188,9 @@ class OpenAIGPTModelTest(unittest.TestCase): for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel): model = model_class(config=config, keep_multihead_output=True) model.eval() - head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device) - head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer - head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer + head_mask = torch.ones(self.n_layer, self.n_head).to(input_ids.device) + head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer + head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer if isinstance(model, OpenAIGPTDoubleHeadsModel): output = model(input_ids, mc_token_ids, head_mask=head_mask) else: diff --git a/tests/modeling_test.py b/tests/modeling_test.py index 80bb3d3c95..126c6fad13 100644 --- a/tests/modeling_test.py +++ b/tests/modeling_test.py @@ -305,9 +305,9 @@ class BertModelTest(unittest.TestCase): else: model = model_class(config=config, keep_multihead_output=True) model.eval() - head_mask = torch.zeros(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device) - head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer - head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer + head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device) + head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer + head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask) if isinstance(model, BertModel):