[inputs_embeds] All PyTorch models
This commit is contained in:
@@ -158,19 +158,26 @@ class BertEmbeddings(nn.Module):
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, position_ids=None):
|
||||
seq_length = input_ids.size(1)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
else:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
seq_length = input_shape[1]
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
@@ -550,6 +557,10 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
||||
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
**encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
|
||||
is configured as a decoder.
|
||||
@@ -615,8 +626,8 @@ class BertModel(BertPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
|
||||
head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
|
||||
head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||
""" Forward pass on the Model.
|
||||
|
||||
The model can behave as an encoder (with only self-attention) as well
|
||||
@@ -632,12 +643,23 @@ class BertModel(BertPreTrainedModel):
|
||||
https://arxiv.org/abs/1706.03762
|
||||
|
||||
"""
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.ones(input_shape)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones_like(input_ids)
|
||||
encoder_attention_mask = torch.ones(input_shape)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
@@ -649,8 +671,8 @@ class BertModel(BertPreTrainedModel):
|
||||
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if attention_mask.dim() == 2:
|
||||
if self.config.is_decoder:
|
||||
batch_size, seq_length = input_ids.size()
|
||||
seq_ids = torch.arange(seq_length, device=input_ids.device)
|
||||
batch_size, seq_length = input_shape
|
||||
seq_ids = torch.arange(seq_length, device=device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
else:
|
||||
@@ -689,7 +711,7 @@ class BertModel(BertPreTrainedModel):
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
|
||||
encoder_outputs = self.encoder(embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
@@ -754,14 +776,15 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
masked_lm_labels=None, next_sentence_label=None):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
@@ -829,7 +852,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
@@ -837,6 +860,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask)
|
||||
|
||||
@@ -908,14 +932,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
next_sentence_label=None):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -975,14 +1000,15 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, labels=None):
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -1049,8 +1075,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, labels=None):
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
num_choices = input_ids.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
@@ -1062,7 +1088,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -1123,14 +1150,15 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, labels=None):
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
@@ -1207,14 +1235,15 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
start_positions=None, end_positions=None):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user