clean up
This commit is contained in:
@@ -127,12 +127,8 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
decoder = decoder_model
|
decoder = decoder_model
|
||||||
else:
|
else:
|
||||||
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
|
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
|
||||||
kwargs['is_decoder'] = True # Make sure the decoder will be an decoder
|
kwargs['is_decoder'] = True # Make sure the decoder will be a decoder
|
||||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||||
else:
|
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
|
||||||
"'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path))
|
|
||||||
|
|
||||||
model = cls(encoder, decoder)
|
model = cls(encoder, decoder)
|
||||||
return model
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user