Allow text generation for ProphetNetForCausalLM (#9707)
* Moved ProphetNetForCausalLM's parent initialization after config update * Added unit tests for generation for ProphetNetForCausalLM
This commit is contained in:
@@ -1883,11 +1883,11 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
|
||||||
# set config for CLM
|
# set config for CLM
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
config.is_encoder_decoder = False
|
config.is_encoder_decoder = False
|
||||||
|
super().__init__(config)
|
||||||
self.prophetnet = ProphetNetDecoderWrapper(config)
|
self.prophetnet = ProphetNetDecoderWrapper(config)
|
||||||
|
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
|
|||||||
@@ -302,6 +302,24 @@ class ProphetNetModelTester:
|
|||||||
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
|
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
|
||||||
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
||||||
|
|
||||||
|
def create_and_check_decoder_generate_with_past_key_value_states(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
decoder_input_ids,
|
||||||
|
attention_mask,
|
||||||
|
decoder_attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = ProphetNetForCausalLM(config=config).to(torch_device).eval()
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_without_past_cache = model.generate(
|
||||||
|
input_ids[:1], num_beams=2, max_length=10, do_sample=True, use_cache=False
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=10, do_sample=True)
|
||||||
|
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
||||||
|
|
||||||
def create_and_check_model_fp16_forward(
|
def create_and_check_model_fp16_forward(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -911,6 +929,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)
|
self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_generate(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_generate_with_past_key_value_states(*config_and_inputs)
|
||||||
|
|
||||||
def test_attn_mask_model(self):
|
def test_attn_mask_model(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_model_with_attn_mask(*config_and_inputs)
|
self.model_tester.check_model_with_attn_mask(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user