Skip warning if tracing with dynamo (#25581)
* Ignore warning if tracing with dynamo
* fix import error
* separate to function
* add test
This commit is contained in:
@@ -81,7 +81,12 @@ from .utils import (
|
||||
strtobool,
|
||||
)
|
||||
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||
from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
|
||||
from .utils.import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
)
|
||||
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
|
||||
from .utils.versions import require_version_core
|
||||
|
||||
@@ -3799,7 +3804,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"""
|
||||
|
||||
# Skip the check during tracing.
|
||||
if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing():
|
||||
if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
|
||||
return
|
||||
|
||||
if (attention_mask is not None) or (self.config.pad_token_id is None):
|
||||
|
||||
@@ -463,6 +463,17 @@ def is_torch_compile_available():
|
||||
return hasattr(torch, "compile")
|
||||
|
||||
|
||||
def is_torchdynamo_compiling():
    """
    Tell whether `torch._dynamo` reports that it is currently compiling.

    Returns:
        `False` when `is_torch_available()` is falsy, or when importing `torch._dynamo` or calling its
        `is_compiling()` raises any exception; otherwise whatever `torch._dynamo.is_compiling()` returns.
    """
    compiling = False

    if is_torch_available():
        try:
            import torch._dynamo as torch_dynamo

            compiling = torch_dynamo.is_compiling()
        except Exception:
            # Best effort: any failure while probing `torch._dynamo` is reported as "not compiling"
            # instead of being propagated to the caller.
            compiling = False

    return compiling
|
||||
|
||||
|
||||
def is_torch_tensorrt_fx_available():
|
||||
if importlib.util.find_spec("torch_tensorrt") is None:
|
||||
return False
|
||||
|
||||
@@ -55,6 +55,7 @@ from transformers.utils import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
)
|
||||
from transformers.utils.import_utils import is_torchdynamo_available
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
@@ -1014,6 +1015,25 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
|
||||
|
||||
if not is_torchdynamo_available():
|
||||
return
|
||||
with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
|
||||
logger.warning_once.cache_clear()
|
||||
from torch._dynamo import config, testing
|
||||
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
|
||||
|
||||
def f(input_ids):
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
|
||||
compile_counter = testing.CompileCounter()
|
||||
opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
|
||||
opt_fn(input_ids)
|
||||
self.assertEqual(compile_counter.frame_count, 0)
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_pretrained_low_mem_new_config(self):
|
||||
|
||||
Reference in New Issue
Block a user