[inputs_embeds] All PyTorch models
This commit is contained in:
@@ -387,6 +387,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
||||
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
|
||||
@@ -436,9 +440,18 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
self.transformer.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self,
|
||||
input_ids, attention_mask=None, head_mask=None):
|
||||
input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None):
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
|
||||
attention_mask = torch.ones(input_shape) # (bs, seq_length)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@@ -455,8 +468,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim)
|
||||
tfmr_output = self.transformer(x=embedding_output,
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
|
||||
tfmr_output = self.transformer(x=inputs_embeds,
|
||||
attn_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
hidden_state = tfmr_output[0]
|
||||
@@ -514,10 +528,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.vocab_projector
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
|
||||
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, masked_lm_labels=None):
|
||||
dlbrt_output = self.distilbert(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
|
||||
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
|
||||
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim)
|
||||
@@ -578,10 +593,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, head_mask=None, labels=None):
|
||||
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
distilbert_output = self.distilbert(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
||||
pooled_output = hidden_state[:, 0] # (bs, dim)
|
||||
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
||||
@@ -652,10 +668,11 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, head_mask=None, start_positions=None, end_positions=None):
|
||||
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
|
||||
distilbert_output = self.distilbert(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
|
||||
|
||||
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
|
||||
|
||||
Reference in New Issue
Block a user