docstring + check
This commit is contained in:
@@ -313,6 +313,9 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
|
|||||||
Mask to nullify selected heads of the self-attention modules.
|
Mask to nullify selected heads of the self-attention modules.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||||
|
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
||||||
|
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want to input a probability distribution of tokens rather than actual tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
@add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
@@ -371,7 +374,9 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
self.h[layer].attn.prune_heads(heads)
|
self.h[layer].attn.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
|
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
|
||||||
if input_ids is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
elif inputs_embeds is not None:
|
elif inputs_embeds is not None:
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
r""" Base class for all TF models.
|
r""" Base class for all TF models.
|
||||||
|
|
||||||
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
||||||
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
||||||
|
|
||||||
Class attributes (overridden by derived classes):
|
Class attributes (overridden by derived classes):
|
||||||
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
r""" Base class for all models.
|
r""" Base class for all models.
|
||||||
|
|
||||||
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
||||||
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
||||||
|
|
||||||
Class attributes (overridden by derived classes):
|
Class attributes (overridden by derived classes):
|
||||||
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||||
|
|||||||
Reference in New Issue
Block a user