adding bert whole words, bertgerman and gpt-2 medium models, head masking
This commit is contained in:
@@ -492,9 +492,12 @@ where
|
||||
- `bert-base-multilingual-cased`: **(New, recommended)** 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
- `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
- `bert-base-german-cased`: Trained on German data only, 12-layer, 768-hidden, 12-heads, 110M parameters [Performance Evaluation](https://deepset.ai/german-bert)
|
||||
- `openai-gpt`: OpenAI English model, 12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
- `transfo-xl-wt103`: Transformer-XL English model trained on wikitext-103, 18-layer, 1024-hidden, 16-heads, 257M parameters
|
||||
- `bert-large-uncased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the the tokens corresponding to a word at once)
|
||||
- `bert-large-cased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the the tokens corresponding to a word at once)
|
||||
- `openai-gpt`: OpenAI GPT English model, 12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
- `gpt2`: OpenAI GPT-2 English model, 12-layer, 768-hidden, 12-heads, 117M parameters
|
||||
- `gpt2-medium`: OpenAI GPT-2 English model, 24-layer, 1024-hidden, 16-heads, 345M parameters
|
||||
- `transfo-xl-wt103`: Transformer-XL English model trained on wikitext-103, 18-layer, 1024-hidden, 16-heads, 257M parameters
|
||||
|
||||
- a path or url to a pretrained model archive containing:
|
||||
|
||||
|
||||
@@ -45,6 +45,8 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
||||
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz",
|
||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz",
|
||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz",
|
||||
}
|
||||
BERT_CONFIG_NAME = 'bert_config.json'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
@@ -279,13 +281,16 @@ class BertEmbeddings(nn.Module):
|
||||
|
||||
|
||||
class BertSelfAttention(nn.Module):
|
||||
def __init__(self, config, output_attentions=False):
|
||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
||||
super(BertSelfAttention, self).__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
||||
self.output_attentions = output_attentions
|
||||
self.keep_multihead_output = keep_multihead_output
|
||||
self.multihead_output = None
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@@ -301,7 +306,7 @@ class BertSelfAttention(nn.Module):
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
def forward(self, hidden_states, attention_mask, head_mask=None):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
@@ -323,7 +328,20 @@ class BertSelfAttention(nn.Module):
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask.unsqueeze(-1).unsqueeze(-1) # We can define heads to mask for each instance in the batch
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
if self.keep_multihead_output:
|
||||
self.multihead_output = context_layer
|
||||
self.multihead_output.retain_grad()
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
@@ -353,8 +371,8 @@ class BertAttention(nn.Module):
|
||||
self.self = BertSelfAttention(config, output_attentions=output_attentions)
|
||||
self.output = BertSelfOutput(config)
|
||||
|
||||
def forward(self, input_tensor, attention_mask):
|
||||
self_output = self.self(input_tensor, attention_mask)
|
||||
def forward(self, input_tensor, attention_mask, head_mask=None):
|
||||
self_output = self.self(input_tensor, attention_mask, head_mask)
|
||||
if self.output_attentions:
|
||||
attentions, self_output = self_output
|
||||
attention_output = self.output(self_output, input_tensor)
|
||||
@@ -400,8 +418,8 @@ class BertLayer(nn.Module):
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
def forward(self, hidden_states, attention_mask, head_mask=None):
|
||||
attention_output = self.attention(hidden_states, attention_mask, head_mask)
|
||||
if self.output_attentions:
|
||||
attentions, attention_output = attention_output
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
@@ -418,11 +436,11 @@ class BertEncoder(nn.Module):
|
||||
layer = BertLayer(config, output_attentions=output_attentions)
|
||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
|
||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
|
||||
all_encoder_layers = []
|
||||
all_attentions = []
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, attention_mask)
|
||||
hidden_states = layer_module(hidden_states, attention_mask, head_mask)
|
||||
if self.output_attentions:
|
||||
attentions, hidden_states = hidden_states
|
||||
all_attentions.append(attentions)
|
||||
@@ -731,7 +749,7 @@ class BertModel(BertPreTrainedModel):
|
||||
self.pooler = BertPooler(config)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
@@ -755,7 +773,8 @@ class BertModel(BertPreTrainedModel):
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
encoded_layers = self.encoder(embedding_output,
|
||||
extended_attention_mask,
|
||||
output_all_encoded_layers=output_all_encoded_layers)
|
||||
output_all_encoded_layers=output_all_encoded_layers,
|
||||
head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
all_attentions, encoded_layers = encoded_layers
|
||||
sequence_output = encoded_layers[-1]
|
||||
@@ -824,9 +843,9 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False)
|
||||
output_all_encoded_layers=False, head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
all_attentions, sequence_output, pooled_output = outputs
|
||||
else:
|
||||
@@ -893,9 +912,10 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False)
|
||||
output_all_encoded_layers=False,
|
||||
head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
all_attentions, sequence_output, _ = outputs
|
||||
else:
|
||||
@@ -961,9 +981,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
self.cls = BertOnlyNSPHead(config)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False)
|
||||
output_all_encoded_layers=False,
|
||||
head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
all_attentions, _, pooled_output = outputs
|
||||
else:
|
||||
@@ -1033,8 +1054,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
all_attentions, _, pooled_output = outputs
|
||||
else:
|
||||
@@ -1104,11 +1125,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
|
||||
outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
all_attentions, _, pooled_output = outputs
|
||||
else:
|
||||
@@ -1180,8 +1201,8 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
all_attentions, sequence_output, _ = outputs
|
||||
else:
|
||||
@@ -1259,8 +1280,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False,
|
||||
head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
all_attentions, sequence_output, _ = outputs
|
||||
else:
|
||||
|
||||
@@ -35,6 +35,8 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
|
||||
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
|
||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
|
||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
|
||||
}
|
||||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
'bert-base-uncased': 512,
|
||||
@@ -45,6 +47,8 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
'bert-base-multilingual-cased': 512,
|
||||
'bert-base-chinese': 512,
|
||||
'bert-base-german-cased': 512,
|
||||
'bert-large-uncased-whole-word-masking': 512,
|
||||
'bert-large-cased-whole-word-masking': 512,
|
||||
}
|
||||
VOCAB_NAME = 'vocab.txt'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user