Fix attn mask ignore logic in training-time trace (#32613)

* fix attn mask logic for training-time trace

* add test

* fix

* fix

* fix

* fix

* fix

* format

* [run-slow] llama

* avoid accelearate

* [run-slow] llama
This commit is contained in:
Longjie Zheng
2024-10-04 13:00:45 -04:00
committed by GitHub
parent 614660fdb9
commit 0d1692a49b
7 changed files with 55 additions and 5 deletions

View File

@@ -4937,6 +4937,49 @@ class ModelTesterMixin:
for i in range(n_iter):
_ = model.generate(**input_ids, do_sample=False)
@slow
@require_torch_gpu
def test_torch_compile_for_training(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
if not hasattr(self, "_torch_compile_train_cls"):
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
cls = self._torch_compile_train_cls
model = cls(config).to(torch_device)
inputs = {
"input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
"attention_mask": torch.tensor(
[[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.int64,
device=torch_device,
),
"position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
"labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
}
# eager backward
set_seed(42)
loss = model(**inputs).loss
loss.backward()
params = {name: param.grad.clone().detach().cpu() for name, param in model.named_parameters()}
model.zero_grad()
del loss
model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
# forward compilation
set_seed(42)
loss = model(**inputs).loss
# backward compilation
loss.backward()
# check grad matches
for name, param in model._orig_mod.named_parameters():
torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
@slow
@require_torch_gpu # Testing cuda graphs.
@require_read_token