This commit is contained in:
thomwolf
2019-10-14 12:14:40 +02:00
parent b7141a1bc6
commit d9d387afce

View File

@@ -127,12 +127,8 @@ class PreTrainedSeq2seq(nn.Module):
decoder = decoder_model decoder = decoder_model
else: else:
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc... kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
kwargs['is_decoder'] = True # Make sure the decoder will be an decoder kwargs['is_decoder'] = True # Make sure the decoder will be a decoder
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
else:
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path))
model = cls(encoder, decoder) model = cls(encoder, decoder)
return model return model