clean up
This commit is contained in:
@@ -127,12 +127,8 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
decoder = decoder_model
|
||||
else:
|
||||
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
|
||||
kwargs['is_decoder'] = True # Make sure the decoder will be an decoder
|
||||
kwargs['is_decoder'] = True # Make sure the decoder will be a decoder
|
||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path))
|
||||
|
||||
model = cls(encoder, decoder)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user