Fix attn mask ignore logic in training-time trace (#32613)
* fix attn mask logic for training-time trace * add test * fix * fix * fix * fix * fix * format * [run-slow] llama * avoid accelearate * [run-slow] llama
This commit is contained in:
@@ -321,6 +321,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "google/gemma-2b"
|
||||
|
||||
# used in `test_torch_compile_for_training`
|
||||
_torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
@@ -808,7 +811,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
prompts = ["Hello I am doing", "Hi today"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
|
||||
model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
|
||||
model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map=torch_device, torch_dtype=torch.float16)
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
# Dynamic Cache
|
||||
|
||||
@@ -319,6 +319,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
# used in `test_torch_compile_for_training`
|
||||
_torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LlamaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
|
||||
@@ -874,7 +877,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
]
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
|
||||
"meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
|
||||
@@ -677,7 +677,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
|
||||
"mistralai/Mistral-7B-v0.1", device_map=torch_device, torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
|
||||
@@ -94,6 +94,8 @@ class NemotronModelTest(GemmaModelTest):
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
|
||||
# used in `test_torch_compile_for_training`
|
||||
_torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = NemotronModelTester(self)
|
||||
|
||||
@@ -4937,6 +4937,49 @@ class ModelTesterMixin:
|
||||
for i in range(n_iter):
|
||||
_ = model.generate(**input_ids, do_sample=False)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_torch_compile_for_training(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
if not hasattr(self, "_torch_compile_train_cls"):
|
||||
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
cls = self._torch_compile_train_cls
|
||||
model = cls(config).to(torch_device)
|
||||
|
||||
inputs = {
|
||||
"input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||
"attention_mask": torch.tensor(
|
||||
[[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
dtype=torch.int64,
|
||||
device=torch_device,
|
||||
),
|
||||
"position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
|
||||
"labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||
}
|
||||
|
||||
# eager backward
|
||||
set_seed(42)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
params = {name: param.grad.clone().detach().cpu() for name, param in model.named_parameters()}
|
||||
model.zero_grad()
|
||||
del loss
|
||||
|
||||
model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
|
||||
# forward compilation
|
||||
set_seed(42)
|
||||
loss = model(**inputs).loss
|
||||
# backward compilation
|
||||
loss.backward()
|
||||
# check grad matches
|
||||
for name, param in model._orig_mod.named_parameters():
|
||||
torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu # Testing cuda graphs.
|
||||
@require_read_token
|
||||
|
||||
Reference in New Issue
Block a user