Fix GPT-NeoX-20B past handling, attention computation (#17811)
* Fix GPT-NeoX-20B past handling, swap attention computation to hopefully avoid NaN, update docs * 20B tests
This commit is contained in:
@@ -38,32 +38,28 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vocab_size (`int`, *optional*, defaults to 30522):
|
vocab_size (`int`, *optional*, defaults to 50432):
|
||||||
Vocabulary size of the GPTNeoX model. Defines the number of different tokens that can be represented by the
|
Vocabulary size of the GPTNeoX model. Defines the number of different tokens that can be represented by the
|
||||||
`inputs_ids` passed when calling [`GPTNeoXModel`].
|
`inputs_ids` passed when calling [`GPTNeoXModel`].
|
||||||
hidden_size (`int`, *optional*, defaults to 768):
|
hidden_size (`int`, *optional*, defaults to 6144):
|
||||||
Dimension of the encoder layers and the pooler layer.
|
Dimension of the encoder layers and the pooler layer.
|
||||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
num_hidden_layers (`int`, *optional*, defaults to 44):
|
||||||
Number of hidden layers in the Transformer encoder.
|
Number of hidden layers in the Transformer encoder.
|
||||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
num_attention_heads (`int`, *optional*, defaults to 64):
|
||||||
Number of attention heads for each attention layer in the Transformer encoder.
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
|
||||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
|
||||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
|
||||||
The dropout ratio for the attention probabilities.
|
|
||||||
rotary_pct (`float`, *optional*, defaults to 0.25):
|
rotary_pct (`float`, *optional*, defaults to 0.25):
|
||||||
percentage of hidden dimensions to allocate to rotary embeddings
|
percentage of hidden dimensions to allocate to rotary embeddings
|
||||||
rotary_emb_base (`int`, *optional*, defaults to 10000)
|
rotary_emb_base (`int`, *optional*, defaults to 10000)
|
||||||
base for computing rotary embeddings frequency
|
base for computing rotary embeddings frequency
|
||||||
max_position_embeddings (`int`, *optional*, defaults to 512):
|
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||||
just in case (e.g., 512 or 1024 or 2048).
|
just in case (e.g., 512 or 1024 or 2048).
|
||||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
initializer_range (`float`, *optional*, defaults to 1e-5):
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||||
The epsilon used by the layer normalization layers.
|
The epsilon used by the layer normalization layers.
|
||||||
@@ -94,8 +90,6 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
num_attention_heads=64,
|
num_attention_heads=64,
|
||||||
intermediate_size=24576,
|
intermediate_size=24576,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
hidden_dropout_prob=0.1,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
rotary_pct=0.25,
|
rotary_pct=0.25,
|
||||||
rotary_emb_base=10000,
|
rotary_emb_base=10000,
|
||||||
max_position_embeddings=2048,
|
max_position_embeddings=2048,
|
||||||
@@ -115,8 +109,6 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.hidden_act = hidden_act
|
self.hidden_act = hidden_act
|
||||||
self.hidden_dropout_prob = hidden_dropout_prob
|
|
||||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
|
||||||
self.rotary_pct = rotary_pct
|
self.rotary_pct = rotary_pct
|
||||||
self.rotary_emb_base = rotary_emb_base
|
self.rotary_emb_base = rotary_emb_base
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
|||||||
@@ -195,7 +195,20 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
|
|
||||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||||
attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
|
attn_scores = torch.zeros(
|
||||||
|
batch_size * num_attention_heads,
|
||||||
|
query_length,
|
||||||
|
key_length,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=key.device,
|
||||||
|
)
|
||||||
|
attn_scores = torch.baddbmm(
|
||||||
|
attn_scores,
|
||||||
|
query,
|
||||||
|
key.transpose(1, 2),
|
||||||
|
beta=1.0,
|
||||||
|
alpha=(1.0 / self.norm_factor),
|
||||||
|
)
|
||||||
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
|
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
|
||||||
|
|
||||||
mask_value = torch.finfo(attn_scores.dtype).min
|
mask_value = torch.finfo(attn_scores.dtype).min
|
||||||
@@ -637,7 +650,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
|||||||
attention_mask = input_ids.new_ones(input_shape)
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past and past[0] is not None:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||||
|
|||||||
@@ -226,6 +226,10 @@ class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Feed forward chunking is not implemented")
|
||||||
|
def test_feed_forward_chunking(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
@@ -247,7 +251,7 @@ class GPTNeoXModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output.shape, expected_shape)
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[[[33.8045, 2.3958, 34.2816], [63.7805, 4.8332, 63.5882], [66.9116, 5.2198, 63.1185]]]
|
[[[33.5938, 2.3789, 34.0312], [63.4688, 4.8164, 63.3438], [66.8750, 5.2422, 63.0625]]]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user