Delete all mentions of Model2Model (#3019)
This commit is contained in:
@@ -234,62 +234,3 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedEncoderDecoder):
|
||||
r"""
|
||||
:class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
|
||||
where both of the encoder and decoder are of the same family. If the
|
||||
name of or that path to a pretrained model is specified the encoder and
|
||||
the decoder will be initialized with the pretrained weight (the
|
||||
cross-attention will be intialized randomly if its weights are not
|
||||
present).
|
||||
|
||||
It is possible to override this behavior and initialize, say, the decoder randomly
|
||||
by creating it beforehand as follows
|
||||
|
||||
config = BertConfig.from_pretrained()
|
||||
decoder = BertForMaskedLM(config)
|
||||
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
""" Tying the encoder and decoders' embeddings together.
|
||||
|
||||
We need for each to get down to the embedding weights. However the
|
||||
different model classes are inconsistent to that respect:
|
||||
- BertModel: embeddings.word_embeddings
|
||||
- RoBERTa: embeddings.word_embeddings
|
||||
- XLMModel: embeddings
|
||||
- GPT2: wte
|
||||
- BertForMaskedLM: bert.embeddings.word_embeddings
|
||||
- RobertaForMaskedLM: roberta.embeddings.word_embeddings
|
||||
|
||||
argument of the XEmbedding layer for each model, but it is "blocked"
|
||||
by a model-specific keyword (bert, )...
|
||||
"""
|
||||
# self._tie_or_clone_weights(self.encoder, self.decoder)
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
|
||||
if (
|
||||
"bert" not in pretrained_model_name_or_path
|
||||
or "roberta" in pretrained_model_name_or_path
|
||||
or "distilbert" in pretrained_model_name_or_path
|
||||
):
|
||||
raise ValueError("Only the Bert model is currently supported.")
|
||||
|
||||
model = super().from_pretrained(
|
||||
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user