Fixup 9d0603148b
This commit is contained in:
@@ -139,6 +139,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
|||||||
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
||||||
|
|
||||||
our_output = model(input_ids)[0]
|
our_output = model(input_ids)[0]
|
||||||
|
if classification_head:
|
||||||
|
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
|
||||||
|
else:
|
||||||
their_output = roberta.model(input_ids)[0]
|
their_output = roberta.model(input_ids)[0]
|
||||||
print(our_output.shape, their_output.shape)
|
print(our_output.shape, their_output.shape)
|
||||||
success = torch.allclose(our_output, their_output, atol=1e-3)
|
success = torch.allclose(our_output, their_output, atol=1e-3)
|
||||||
|
|||||||
Reference in New Issue
Block a user