Update to match renamed attributes in fairseq master (#5972)
* Update to match renamed attributes in fairseq master. RobertaModel no longer has the model.decoder and args.num_classes attributes as of 5/28/20 (model.decoder was renamed to model.encoder). * Quality Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -47,7 +47,7 @@ def convert_roberta_checkpoint_to_pytorch(
|
|||||||
"""
|
"""
|
||||||
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
|
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
|
||||||
roberta.eval() # disable dropout
|
roberta.eval() # disable dropout
|
||||||
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
|
roberta_sent_encoder = roberta.model.encoder.sentence_encoder
|
||||||
config = RobertaConfig(
|
config = RobertaConfig(
|
||||||
vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
|
vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
|
||||||
hidden_size=roberta.args.encoder_embed_dim,
|
hidden_size=roberta.args.encoder_embed_dim,
|
||||||
@@ -59,7 +59,7 @@ def convert_roberta_checkpoint_to_pytorch(
|
|||||||
layer_norm_eps=1e-5, # PyTorch default used in fairseq
|
layer_norm_eps=1e-5, # PyTorch default used in fairseq
|
||||||
)
|
)
|
||||||
if classification_head:
|
if classification_head:
|
||||||
config.num_labels = roberta.args.num_classes
|
config.num_labels = roberta.model.classification_heads["mnli"].out_proj.weight.shape[0]
|
||||||
print("Our BERT config:", config)
|
print("Our BERT config:", config)
|
||||||
|
|
||||||
model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
|
model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
|
||||||
@@ -126,12 +126,12 @@ def convert_roberta_checkpoint_to_pytorch(
|
|||||||
model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
|
model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
|
||||||
else:
|
else:
|
||||||
# LM Head
|
# LM Head
|
||||||
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
|
model.lm_head.dense.weight = roberta.model.encoder.lm_head.dense.weight
|
||||||
model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
|
model.lm_head.dense.bias = roberta.model.encoder.lm_head.dense.bias
|
||||||
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
|
model.lm_head.layer_norm.weight = roberta.model.encoder.lm_head.layer_norm.weight
|
||||||
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
|
model.lm_head.layer_norm.bias = roberta.model.encoder.lm_head.layer_norm.bias
|
||||||
model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
|
model.lm_head.decoder.weight = roberta.model.encoder.lm_head.weight
|
||||||
model.lm_head.decoder.bias = roberta.model.decoder.lm_head.bias
|
model.lm_head.decoder.bias = roberta.model.encoder.lm_head.bias
|
||||||
|
|
||||||
# Let's check that we get the same results.
|
# Let's check that we get the same results.
|
||||||
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
||||||
|
|||||||
Reference in New Issue
Block a user