Add support for torch.compile dynamic shapes (#30560)

* add torch.compile dynamic support

* Add SDPA dynamic shapes compile test & improve SDPA comment

* comment consistency
This commit is contained in:
Benjamin Warner
2024-05-20 03:36:57 -05:00
committed by GitHub
parent fce78fd0e9
commit cd6bd0af34
27 changed files with 190 additions and 60 deletions

View File

@@ -4014,6 +4014,47 @@ class ModelTesterMixin:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
_ = model(**inputs_dict)
@require_torch_sdpa
@require_torch_gpu
@slow
def test_sdpa_can_compile_dynamic(self):
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if not torch.version.cuda or major < 8:
self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0")
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
if config.model_type in ["dbrx"]:
self.skipTest(
"DBRX (transformers==4.40) requires a modification to support dynamic shapes with compile."
)
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
model.to(torch_device)
# For PyTorch 2.1 - 2.3.0 set `dynamic=True`. In the future setting `dynamic=None` and using `torch._dynamo.mark_dynamic()`
# on input tensors will be required. `mark_dynamic` currently raises inconsistent shape errors.
model = torch.compile(model, dynamic=True)
inputs_dict.pop("attention_mask", None)
inputs_dict.pop("decoder_attention_mask", None)
for name, inp in inputs_dict.items():
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
inputs_dict[name] = inp.to(torch.float16)
# use no_grad to save some memory
with torch.no_grad():
_ = model(**inputs_dict)
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):