[GPT2] Correct gradient checkpointing (#9308)
* correct gpt2 * fix gpt2 * fix use_cache ordering * correct past tolerance * fix for all cases * style
This commit is contained in:
committed by
GitHub
parent
21fc676645
commit
61443cd7d9
@@ -184,9 +184,9 @@ class Attention(nn.Module):
|
||||
if head_mask is not None:
|
||||
w = w * head_mask
|
||||
|
||||
outputs = [torch.matmul(w, v)]
|
||||
outputs = (torch.matmul(w, v),)
|
||||
if output_attentions:
|
||||
outputs.append(w)
|
||||
outputs += (w,)
|
||||
return outputs
|
||||
|
||||
def merge_heads(self, x):
|
||||
@@ -234,7 +234,7 @@ class Attention(nn.Module):
|
||||
if use_cache is True:
|
||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
||||
else:
|
||||
present = (None,)
|
||||
present = None
|
||||
|
||||
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
|
||||
a = attn_outputs[0]
|
||||
@@ -243,8 +243,7 @@ class Attention(nn.Module):
|
||||
a = self.c_proj(a)
|
||||
a = self.resid_dropout(a)
|
||||
|
||||
outputs = [a, present] + attn_outputs[1:]
|
||||
return outputs # a, present, (attentions)
|
||||
return (a, present) + attn_outputs[1:] # a, present, (attentions)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
@@ -321,7 +320,11 @@ class Block(nn.Module):
|
||||
# residual connection
|
||||
hidden_states = hidden_states + feed_forward_hidden_states
|
||||
|
||||
outputs = [hidden_states] + outputs
|
||||
if use_cache:
|
||||
outputs = (hidden_states,) + outputs
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:]
|
||||
|
||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
|
||||
|
||||
@@ -740,14 +743,14 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states, present = outputs[:2]
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (present,)
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2],)
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3],)
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
|
||||
Reference in New Issue
Block a user