From ad5e9b6c6adaf630304d1f5ce11d6377d40bba4f Mon Sep 17 00:00:00 2001 From: TheWall9 Date: Tue, 4 Apr 2023 18:50:33 +0800 Subject: [PATCH] [Roformer] Fixing a bug in RoFormerEncoder where it was ignoring the length of past_key_values when generating as a decoder (#22416) * fix RoFormerEncoder postion embedding when generate as decoder * make fixup * add test case for check generate with past key values * remove duplicating code --- .../models/roformer/modeling_roformer.py | 12 +++++----- .../models/roformer/test_modeling_roformer.py | 23 +++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index e3e1d4b6be..d6752f98d7 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -259,11 +259,6 @@ class RoFormerSelfAttention(nn.Module): key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) @@ -276,6 +271,9 @@ class RoFormerSelfAttention(nn.Module): query_layer, key_layer = self.apply_rotary_position_embeddings( sinusoidal_pos, query_layer, key_layer ) + if past_key_value is not None: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention @@ -566,8 +564,10 @@ class RoFormerEncoder(nn.Module): all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] - sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1])[None, None, :, :] + sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :] next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): diff --git a/tests/models/roformer/test_modeling_roformer.py b/tests/models/roformer/test_modeling_roformer.py index 9b91c66cb2..357e126a04 100644 --- a/tests/models/roformer/test_modeling_roformer.py +++ b/tests/models/roformer/test_modeling_roformer.py @@ -220,6 +220,25 @@ class RoFormerModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_for_generate_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = RoFormerForCausalLM(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids[:1], num_beams=2, max_length=15, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=15, do_sample=True) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -405,6 +424,10 @@ class RoFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + def test_for_generate_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_generate_causal_lm(*config_and_inputs) + def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)