Fixup 9d0603148b
This commit is contained in:
@@ -139,6 +139,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
||||
|
||||
our_output = model(input_ids)[0]
|
||||
if classification_head:
|
||||
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
|
||||
else:
|
||||
their_output = roberta.model(input_ids)[0]
|
||||
print(our_output.shape, their_output.shape)
|
||||
success = torch.allclose(our_output, their_output, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user