Skip warning if tracing with dynamo (#25581)
* Ignore warning if tracing with dynamo * fix import error * separate to function * add test
This commit is contained in:
@@ -55,6 +55,7 @@ from transformers.utils import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
)
|
||||
from transformers.utils.import_utils import is_torchdynamo_available
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
@@ -1014,6 +1015,25 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
|
||||
|
||||
if not is_torchdynamo_available():
|
||||
return
|
||||
with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
|
||||
logger.warning_once.cache_clear()
|
||||
from torch._dynamo import config, testing
|
||||
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
|
||||
|
||||
def f(input_ids):
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
|
||||
compile_counter = testing.CompileCounter()
|
||||
opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
|
||||
opt_fn(input_ids)
|
||||
self.assertEqual(compile_counter.frame_count, 0)
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_pretrained_low_mem_new_config(self):
|
||||
|
||||
Reference in New Issue
Block a user