[inputs_embeds] All PyTorch models

This commit is contained in:
Julien Chaumond
2019-11-05 00:39:18 +00:00
parent 9eddf44b7a
commit 00337e9687
21 changed files with 361 additions and 147 deletions

View File

@@ -387,6 +387,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
@@ -436,9 +440,18 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.transformer.layer[layer].attention.prune_heads(heads)
def forward(self,
input_ids, attention_mask=None, head_mask=None):
input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
attention_mask = torch.ones(input_shape) # (bs, seq_length)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -455,8 +468,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(x=embedding_output,
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(x=inputs_embeds,
attn_mask=attention_mask,
head_mask=head_mask)
hidden_state = tfmr_output[0]
@@ -514,10 +528,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def get_output_embeddings(self):
return self.vocab_projector
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, masked_lm_labels=None):
dlbrt_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim)
@@ -578,10 +593,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
@@ -652,10 +668,11 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, head_mask=None, start_positions=None, end_positions=None):
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)