From 5392f12e1614383270ae8df524415a1f6b555773 Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Tue, 29 Oct 2024 06:30:02 -0700 Subject: [PATCH] Bert is ExecuTorch compatible (#34424) Co-authored-by: Guang Yang --- tests/models/bert/test_modeling_bert.py | 42 +++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index 8ac1c3d2b4..aa9835d8cd 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -16,6 +16,8 @@ import os import tempfile import unittest +from packaging import version + from transformers import AutoTokenizer, BertConfig, is_torch_available from transformers.models.auto import get_values from transformers.testing_utils import ( @@ -749,3 +751,43 @@ class BertModelIntegrationTest(unittest.TestCase): self.assertTrue( torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4) ) + + @slow + def test_export(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + bert_model = "google-bert/bert-base-uncased" + device = "cpu" + attn_implementation = "sdpa" + max_length = 512 + + tokenizer = AutoTokenizer.from_pretrained(bert_model) + inputs = tokenizer( + "the man worked as a [MASK].", + return_tensors="pt", + padding="max_length", + max_length=max_length, + ) + + model = BertForMaskedLM.from_pretrained( + bert_model, + device_map=device, + attn_implementation=attn_implementation, + use_cache=True, + ) + + logits = model(**inputs).logits + eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices) + self.assertEqual(eg_predicted_mask.split(), ["carpenter", "waiter", "barber", "mechanic", "salesman"]) + + exported_program = torch.export.export( + model, + args=(inputs["input_ids"],), + kwargs={"attention_mask": inputs["attention_mask"]}, + strict=True, + ) + + result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"]) + ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices) + self.assertEqual(eg_predicted_mask, ep_predicted_mask)