[inputs_embeds] All PyTorch models

This commit is contained in:
Julien Chaumond
2019-11-05 00:39:18 +00:00
parent 9eddf44b7a
commit 00337e9687
21 changed files with 361 additions and 147 deletions

View File

@@ -255,6 +255,10 @@ XXX_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",

View File

@@ -238,6 +238,10 @@ XXX_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
@@ -295,7 +299,7 @@ class XxxModel(XxxPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@@ -449,14 +453,15 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
pooled_output = outputs[1]
@@ -520,14 +525,15 @@ class XxxForTokenClassification(XxxPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
@@ -603,14 +609,15 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]