Skip warning if tracing with dynamo (#25581)

* Ignore warning if tracing with dynamo

* fix import error

* separate into a function

* add test
This commit is contained in:
Angela Yi
2023-09-08 12:13:33 -07:00
committed by GitHub
parent 18ee1fe762
commit 6c26faa159
3 changed files with 38 additions and 2 deletions

View File

@@ -55,6 +55,7 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.utils.import_utils import is_torchdynamo_available
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -1014,6 +1015,25 @@ class ModelUtilsTest(TestCasePlus):
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
if not is_torchdynamo_available():
return
with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
logger.warning_once.cache_clear()
from torch._dynamo import config, testing
config = PretrainedConfig()
config.pad_token_id = 0
model = ModelWithHead(config)
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
def f(input_ids):
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
compile_counter = testing.CompileCounter()
opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
opt_fn(input_ids)
self.assertEqual(compile_counter.frame_count, 0)
@require_torch_gpu
@slow
def test_pretrained_low_mem_new_config(self):