Fix attn mask ignore logic in training-time trace (#32613)
* fix attn mask logic for training-time trace * add test * fix * fix * fix * fix * fix * format * [run-slow] llama * avoid accelearate * [run-slow] llama
This commit is contained in:
@@ -4937,6 +4937,49 @@ class ModelTesterMixin:
|
||||
for i in range(n_iter):
|
||||
_ = model.generate(**input_ids, do_sample=False)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_torch_compile_for_training(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
if not hasattr(self, "_torch_compile_train_cls"):
|
||||
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
cls = self._torch_compile_train_cls
|
||||
model = cls(config).to(torch_device)
|
||||
|
||||
inputs = {
|
||||
"input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||
"attention_mask": torch.tensor(
|
||||
[[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
dtype=torch.int64,
|
||||
device=torch_device,
|
||||
),
|
||||
"position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
|
||||
"labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||
}
|
||||
|
||||
# eager backward
|
||||
set_seed(42)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
params = {name: param.grad.clone().detach().cpu() for name, param in model.named_parameters()}
|
||||
model.zero_grad()
|
||||
del loss
|
||||
|
||||
model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
|
||||
# forward compilation
|
||||
set_seed(42)
|
||||
loss = model(**inputs).loss
|
||||
# backward compilation
|
||||
loss.backward()
|
||||
# check grad matches
|
||||
for name, param in model._orig_mod.named_parameters():
|
||||
torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu # Testing cuda graphs.
|
||||
@require_read_token
|
||||
|
||||
Reference in New Issue
Block a user