Allow text generation for ProphetNetForCausalLM (#9707)
* Moved ProphetNetForCausalLM's parent initialization after config update * Added unit tests for generation for ProphetNetForCausalLM
This commit is contained in:
@@ -1883,11 +1883,11 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
||||
)
|
||||
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# set config for CLM
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
config.is_encoder_decoder = False
|
||||
super().__init__(config)
|
||||
self.prophetnet = ProphetNetDecoderWrapper(config)
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
||||
Reference in New Issue
Block a user