This commit is contained in:
Julien Chaumond
2019-08-12 12:28:55 -04:00
parent 75d5f98fd2
commit b3d83d68db

View File

@@ -139,7 +139,10 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
our_output = model(input_ids)[0]
their_output = roberta.model(input_ids)[0]
if classification_head:
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
else:
their_output = roberta.model(input_ids)[0]
print(our_output.shape, their_output.shape)
success = torch.allclose(our_output, their_output, atol=1e-3)
print(