diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 41150e4c17..d473e8758a 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1883,11 +1883,11 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ) class ProphetNetForCausalLM(ProphetNetPreTrainedModel): def __init__(self, config): - super().__init__(config) # set config for CLM config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False + super().__init__(config) self.prophetnet = ProphetNetDecoderWrapper(config) self.padding_idx = config.pad_token_id diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index a88b2653c1..c9ba56396e 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -302,6 +302,24 @@ class ProphetNetModelTester: output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + def create_and_check_decoder_generate_with_past_key_value_states( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = ProphetNetForCausalLM(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids[:1], num_beams=2, max_length=10, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=10, do_sample=True) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + def create_and_check_model_fp16_forward( self, config, @@ -911,6 +929,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs) + def test_encoder_decoder_model_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_generate_with_past_key_value_states(*config_and_inputs) + def test_attn_mask_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_model_with_attn_mask(*config_and_inputs)