From 34858ae1d9e11dc51100b26ac468770c81c8afc1 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 17 Jun 2019 11:02:39 +0200 Subject: [PATCH] adding bert whole words, bertgerman and gpt-2 medium models, head masking --- README.md | 7 ++- pytorch_pretrained_bert/modeling.py | 71 ++++++++++++++++--------- pytorch_pretrained_bert/tokenization.py | 4 ++ 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 1d734f9df9..b8a4a9d5a4 100644 --- a/README.md +++ b/README.md @@ -492,9 +492,12 @@ where - `bert-base-multilingual-cased`: **(New, recommended)** 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters - `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters - `bert-base-german-cased`: Trained on German data only, 12-layer, 768-hidden, 12-heads, 110M parameters [Performance Evaluation](https://deepset.ai/german-bert) - - `openai-gpt`: OpenAI English model, 12-layer, 768-hidden, 12-heads, 110M parameters - - `transfo-xl-wt103`: Transformer-XL English model trained on wikitext-103, 18-layer, 1024-hidden, 16-heads, 257M parameters + - `bert-large-uncased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the the tokens corresponding to a word at once) + - `bert-large-cased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the the tokens corresponding to a word at once) + - `openai-gpt`: OpenAI GPT English model, 12-layer, 768-hidden, 12-heads, 110M parameters - `gpt2`: OpenAI GPT-2 English model, 12-layer, 768-hidden, 12-heads, 117M parameters + - `gpt2-medium`: OpenAI GPT-2 English model, 24-layer, 1024-hidden, 16-heads, 345M parameters + - `transfo-xl-wt103`: Transformer-XL English model trained on wikitext-103, 18-layer, 1024-hidden, 16-heads, 257M parameters - a path or url to a pretrained model archive containing: diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 3006d8e971..11a7191df5 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -45,6 +45,8 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz", } BERT_CONFIG_NAME = 'bert_config.json' TF_WEIGHTS_NAME = 'model.ckpt' @@ -279,13 +281,16 @@ class BertEmbeddings(nn.Module): class BertSelfAttention(nn.Module): - def __init__(self, config, output_attentions=False): + def __init__(self, config, output_attentions=False, keep_multihead_output=False): super(BertSelfAttention, self).__init__() if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) self.output_attentions = output_attentions + self.keep_multihead_output = keep_multihead_output + self.multihead_output = None + self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -301,7 +306,7 @@ class BertSelfAttention(nn.Module): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask): + def forward(self, hidden_states, attention_mask, head_mask=None): mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) @@ -323,7 +328,20 @@ class BertSelfAttention(nn.Module): # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) + # Mask heads if we want to + # attention_probs has shape bsz x n_heads x N x N + if head_mask is not None: + if head_mask.dim() == 1: + head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + elif head_mask.dim() == 2: + head_mask.unsqueeze(-1).unsqueeze(-1) # We can define heads to mask for each instance in the batch + attention_probs = attention_probs * head_mask + context_layer = torch.matmul(attention_probs, value_layer) + if self.keep_multihead_output: + self.multihead_output = context_layer + self.multihead_output.retain_grad() + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) @@ -353,8 +371,8 @@ class BertAttention(nn.Module): self.self = BertSelfAttention(config, output_attentions=output_attentions) self.output = BertSelfOutput(config) - def forward(self, input_tensor, attention_mask): - self_output = self.self(input_tensor, attention_mask) + def forward(self, input_tensor, attention_mask, head_mask=None): + self_output = self.self(input_tensor, attention_mask, head_mask) if self.output_attentions: attentions, self_output = self_output attention_output = self.output(self_output, input_tensor) @@ -400,8 +418,8 @@ class BertLayer(nn.Module): self.intermediate = BertIntermediate(config) self.output = BertOutput(config) - def forward(self, hidden_states, attention_mask): - attention_output = self.attention(hidden_states, attention_mask) + def forward(self, hidden_states, attention_mask, head_mask=None): + attention_output = self.attention(hidden_states, attention_mask, head_mask) if self.output_attentions: attentions, attention_output = attention_output intermediate_output = self.intermediate(attention_output) @@ -418,11 +436,11 @@ class BertEncoder(nn.Module): layer = BertLayer(config, output_attentions=output_attentions) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) - def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None): all_encoder_layers = [] all_attentions = [] for layer_module in self.layer: - hidden_states = layer_module(hidden_states, attention_mask) + hidden_states = layer_module(hidden_states, attention_mask, head_mask) if self.output_attentions: attentions, hidden_states = hidden_states all_attentions.append(attentions) @@ -731,7 +749,7 @@ class BertModel(BertPreTrainedModel): self.pooler = BertPooler(config) self.apply(self.init_bert_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): + def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: @@ -755,7 +773,8 @@ class BertModel(BertPreTrainedModel): embedding_output = self.embeddings(input_ids, token_type_ids) encoded_layers = self.encoder(embedding_output, extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers) + output_all_encoded_layers=output_all_encoded_layers, + head_mask=head_mask) if self.output_attentions: all_attentions, encoded_layers = encoded_layers sequence_output = encoded_layers[-1] @@ -824,9 +843,9 @@ class BertForPreTraining(BertPreTrainedModel): self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) self.apply(self.init_bert_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None): outputs = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False) + output_all_encoded_layers=False, head_mask=head_mask) if self.output_attentions: all_attentions, sequence_output, pooled_output = outputs else: @@ -893,9 +912,10 @@ class BertForMaskedLM(BertPreTrainedModel): self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) self.apply(self.init_bert_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None): outputs = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False) + output_all_encoded_layers=False, + head_mask=head_mask) if self.output_attentions: all_attentions, sequence_output, _ = outputs else: @@ -961,9 +981,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel): self.cls = BertOnlyNSPHead(config) self.apply(self.init_bert_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): + def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None): outputs = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False) + output_all_encoded_layers=False, + head_mask=head_mask) if self.output_attentions: all_attentions, _, pooled_output = outputs else: @@ -1033,8 +1054,8 @@ class BertForSequenceClassification(BertPreTrainedModel): self.classifier = nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): - outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): + outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask) if self.output_attentions: all_attentions, _, pooled_output = outputs else: @@ -1104,11 +1125,11 @@ class BertForMultipleChoice(BertPreTrainedModel): self.classifier = nn.Linear(config.hidden_size, 1) self.apply(self.init_bert_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) + outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, head_mask=head_mask) if self.output_attentions: all_attentions, _, pooled_output = outputs else: @@ -1180,8 +1201,8 @@ class BertForTokenClassification(BertPreTrainedModel): self.classifier = nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): - outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): + outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask) if self.output_attentions: all_attentions, sequence_output, _ = outputs else: @@ -1259,8 +1280,10 @@ class BertForQuestionAnswering(BertPreTrainedModel): self.qa_outputs = nn.Linear(config.hidden_size, 2) self.apply(self.init_bert_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): - outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, head_mask=None): + outputs = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False, + head_mask=head_mask) if self.output_attentions: all_attentions, sequence_output, _ = outputs else: diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index 26c172dc69..9a700cef0f 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -35,6 +35,8 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = { 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", } PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 'bert-base-uncased': 512, @@ -45,6 +47,8 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 'bert-base-multilingual-cased': 512, 'bert-base-chinese': 512, 'bert-base-german-cased': 512, + 'bert-large-uncased-whole-word-masking': 512, + 'bert-large-cased-whole-word-masking': 512, } VOCAB_NAME = 'vocab.txt'