From b3d83d68db2db037a439516c24c593d4a85035a7 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 12 Aug 2019 12:28:55 -0400 Subject: [PATCH] Fixup 9d0603148bc34255fad0cad73ce438ecd7306322 --- .../convert_roberta_checkpoint_to_pytorch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py b/pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py index e4e8fbb25d..0a8967426e 100644 --- a/pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py @@ -139,7 +139,10 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_ input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1 our_output = model(input_ids)[0] - their_output = roberta.model(input_ids)[0] + if classification_head: + their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids)) + else: + their_output = roberta.model(input_ids)[0] print(our_output.shape, their_output.shape) success = torch.allclose(our_output, their_output, atol=1e-3) print(