Allow text generation for ProphetNetForCausalLM (#9707)

* Moved ProphetNetForCausalLM's parent initialization after config update

* Added unit tests for generation for ProphetNetForCausalLM
This commit is contained in:
guillaume-be
2021-01-21 11:13:38 +01:00
committed by GitHub
parent 910aa89671
commit fb36c273a2
2 changed files with 23 additions and 1 deletions

View File

@@ -1883,11 +1883,11 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
)
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# set config for CLM
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
super().__init__(config)
self.prophetnet = ProphetNetDecoderWrapper(config)
self.padding_idx = config.pad_token_id