Merge branch 'master' into fix-ctrl-past
This commit is contained in:
@@ -221,7 +221,8 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
|
||||
**past**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@@ -237,6 +238,10 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
||||
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
@@ -249,7 +254,8 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
**past**:
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
@@ -303,17 +309,26 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.h[layer].attn.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
else:
|
||||
past_length = past[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
@@ -355,9 +370,10 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
token_type_embeds = 0
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
inputs_embeds = self.w(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.w(input_ids)
|
||||
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
||||
seq_len = input_ids.shape[-1]
|
||||
seq_len = input_shape.shape[-1]
|
||||
mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(inputs_embeds.device)
|
||||
|
||||
inputs_embeds *= np.sqrt(self.d_model_size)
|
||||
@@ -424,7 +440,8 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
**past**:
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
@@ -456,14 +473,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
labels=None):
|
||||
transformer_outputs = self.transformer(input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user