From 663c8512398f864cb886879bbb64471777421413 Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Tue, 5 Nov 2024 04:41:48 -0800 Subject: [PATCH] DistilBERT is ExecuTorch compatible (#34475) * DistillBERT is ExecuTorch compatible * [run_slow] distilbert * [run_slow] distilbert --------- Co-authored-by: Guang Yang --- .../distilbert/test_modeling_distilbert.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py index 3a74a1557c..d4c51cea12 100644 --- a/tests/models/distilbert/test_modeling_distilbert.py +++ b/tests/models/distilbert/test_modeling_distilbert.py @@ -30,6 +30,7 @@ if is_torch_available(): import torch from transformers import ( + AutoTokenizer, DistilBertForMaskedLM, DistilBertForMultipleChoice, DistilBertForQuestionAnswering, @@ -38,6 +39,7 @@ if is_torch_available(): DistilBertModel, ) from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings + from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 class DistilBertModelTester: @@ -420,3 +422,45 @@ class DistilBertModelIntergrationTest(unittest.TestCase): ) self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4)) + + @slow + def test_export(self): + if not is_torch_greater_or_equal_than_2_4: + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + distilbert_model = "distilbert-base-uncased" + device = "cpu" + attn_implementation = "sdpa" + max_length = 64 + + tokenizer = AutoTokenizer.from_pretrained(distilbert_model) + inputs = tokenizer( + f"Paris is the {tokenizer.mask_token} of France.", + return_tensors="pt", + padding="max_length", + max_length=max_length, + ) + + model = DistilBertForMaskedLM.from_pretrained( + distilbert_model, + device_map=device, + attn_implementation=attn_implementation, + ) + + logits = model(**inputs).logits + eager_predicted_mask = tokenizer.decode(logits[0, 4].topk(5).indices) + self.assertEqual( + eager_predicted_mask.split(), + ["capital", "birthplace", "northernmost", "centre", "southernmost"], + ) + + exported_program = torch.export.export( + model, + args=(inputs["input_ids"],), + kwargs={"attention_mask": inputs["attention_mask"]}, + strict=True, + ) + + result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"]) + exported_predicted_mask = tokenizer.decode(result.logits[0, 4].topk(5).indices) + self.assertEqual(eager_predicted_mask, exported_predicted_mask)