Fix cache for GPT-Neo-X (#17764)
* Fix cache for GPT-Neo-X * Add more tests
This commit is contained in:
@@ -2458,7 +2458,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
if offload_state_dict:
|
||||
# Load back temporarily offloaded state dict
|
||||
load_offloaded_weights(model, state_dict_index, state_dict_folder)
|
||||
load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
|
||||
shutil.rmtree(state_dict_folder)
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
|
||||
@@ -143,7 +143,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
past_value = layer_past[1]
|
||||
key = torch.cat((past_key, key), dim=-2)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
present = None if use_cache else (key, value)
|
||||
present = (key, value) if use_cache else None
|
||||
|
||||
# Compute attention
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
|
||||
@@ -218,6 +218,14 @@ class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
|
||||
|
||||
def test_decoder_model_past_large_inputs(self):
|
||||
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask)
|
||||
|
||||
def test_model_for_causal_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
Reference in New Issue
Block a user