Delete all mentions of Model2Model (#3019)

This commit is contained in:
Sam Shleifer
2020-02-26 11:36:27 -05:00
committed by GitHub
parent bb7c468520
commit 9df74b8bc4
4 changed files with 1 additions and 203 deletions

View File

@@ -234,62 +234,3 @@ class PreTrainedEncoderDecoder(nn.Module):
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedEncoderDecoder):
r"""
:class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
where both of the encoder and decoder are of the same family. If the
name of or that path to a pretrained model is specified the encoder and
the decoder will be initialized with the pretrained weight (the
cross-attention will be intialized randomly if its weights are not
present).
It is possible to override this behavior and initialize, say, the decoder randomly
by creating it beforehand as follows
config = BertConfig.from_pretrained()
decoder = BertForMaskedLM(config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tie_weights()
def tie_weights(self):
""" Tying the encoder and decoders' embeddings together.
We need for each to get down to the embedding weights. However the
different model classes are inconsistent to that respect:
- BertModel: embeddings.word_embeddings
- RoBERTa: embeddings.word_embeddings
- XLMModel: embeddings
- GPT2: wte
- BertForMaskedLM: bert.embeddings.word_embeddings
- RobertaForMaskedLM: roberta.embeddings.word_embeddings
argument of the XEmbedding layer for each model, but it is "blocked"
by a model-specific keyword (bert, )...
"""
# self._tie_or_clone_weights(self.encoder, self.decoder)
pass
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
if (
"bert" not in pretrained_model_name_or_path
or "roberta" in pretrained_model_name_or_path
or "distilbert" in pretrained_model_name_or_path
):
raise ValueError("Only the Bert model is currently supported.")
model = super().from_pretrained(
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
*args,
**kwargs,
)
return model