fixing CTRL tests and OpenAI GPT tests
This commit is contained in:
@@ -303,11 +303,6 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
@@ -349,42 +344,51 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
else:
|
||||
head_mask = [None] * self.config.n_layer
|
||||
|
||||
x = self.w(input_ids)
|
||||
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
token_type_embeds = self.w(token_type_ids)
|
||||
token_type_embeds *= np.sqrt(self.d_model_size)
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
inputs_embeds = self.w(input_ids)
|
||||
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
||||
seq_len = input_ids.shape[-1]
|
||||
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
|
||||
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
|
||||
|
||||
x *= np.sqrt(self.d_model_size)
|
||||
inputs_embeds *= np.sqrt(self.d_model_size)
|
||||
|
||||
pos_x = self.pos_encoding[position_ids, :].to(x.device)
|
||||
x += pos_x
|
||||
pos_embeds = self.pos_encoding[position_ids, :].to(inputs_embeds.device)
|
||||
|
||||
x = self.dropout(x)
|
||||
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
|
||||
|
||||
output_shape = input_shape + (x.size(-1),)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
output_shape = input_shape + (inputs_embeds.size(-1),)
|
||||
presents = ()
|
||||
all_hidden_states = ()
|
||||
all_attentions = []
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (x.view(*output_shape),)
|
||||
outputs = h(x,
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
outputs = h(hidden_states,
|
||||
mask,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i])
|
||||
x, present = outputs[:2]
|
||||
hidden_states, present = outputs[:2]
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions.append(outputs[2])
|
||||
|
||||
x = self.layernorm(x)
|
||||
x = x.view(*output_shape)
|
||||
hidden_states = self.layernorm(hidden_states)
|
||||
hidden_states = hidden_states.view(*output_shape)
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (x,)
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (x, presents)
|
||||
outputs = (hidden_states, presents)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
|
||||
Reference in New Issue
Block a user