revert black formatting to conform with lib style
This commit is contained in:
@@ -43,13 +43,7 @@ class PreTrainedSeq2seq(PreTrainedModel):
|
|||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
cls,
|
|
||||||
encoder_pretrained_model_name_or_path,
|
|
||||||
decoder_pretrained_model_name_or_path,
|
|
||||||
*model_args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
r""" Instantiates an encoder and a decoder from one or two base classes
|
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||||
of the library from pre-trained model checkpoints.
|
of the library from pre-trained model checkpoints.
|
||||||
|
|
||||||
@@ -190,7 +184,7 @@ class Model2Model(PreTrainedSeq2seq):
|
|||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
""" Tying the encoder and decoders' embeddings together.
|
""" Tying the encoder and decoders' embeddings together.
|
||||||
|
|
||||||
We need for each to get down to the embedding weights. However the
|
We need for each to get down to the embedding weights. However the
|
||||||
different model classes are inconsistent to that respect:
|
different model classes are inconsistent to that respect:
|
||||||
- BertModel: embeddings.word_embeddings
|
- BertModel: embeddings.word_embeddings
|
||||||
- RoBERTa: embeddings.word_embeddings
|
- RoBERTa: embeddings.word_embeddings
|
||||||
|
|||||||
Reference in New Issue
Block a user