From fdb120805c93d101efc03ec716e9153562054db7 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 20 Jun 2022 08:43:36 -0400 Subject: [PATCH] Fix cache for GPT-Neo-X (#17764) * Fix cache for GPT-Neo-X * Add more tests --- src/transformers/modeling_utils.py | 2 +- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- tests/models/gpt_neox/test_modeling_gpt_neox.py | 8 ++++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 09bf919992..b95bdbe958 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2458,7 +2458,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if offload_state_dict: # Load back temporarily offloaded state dict - load_offloaded_weights(model, state_dict_index, state_dict_folder) + load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder) shutil.rmtree(state_dict_folder) if len(error_msgs) > 0: diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 8a1879a624..7ca3d60e09 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -143,7 +143,7 @@ class GPTNeoXAttention(nn.Module): past_value = layer_past[1] key = torch.cat((past_key, key), dim=-2) value = torch.cat((past_value, value), dim=-2) - present = None if use_cache else (key, value) + present = (key, value) if use_cache else None # Compute attention attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index a4fb95384e..a22fd78fc8 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -218,6 +218,14 @@ class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase): self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask) + def test_decoder_model_past_large_inputs(self): + config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask) + + def test_model_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: