From 0d1692a49bc1d70d72c99ac814773bcc2d3a98be Mon Sep 17 00:00:00 2001 From: Longjie Zheng <32992656+zhenglongjiepheonix@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:00:45 -0400 Subject: [PATCH] Fix attn mask ignore logic in training-time trace (#32613) * fix attn mask logic for training-time trace * add test * fix * fix * fix * fix * fix * format * [run-slow] llama * avoid accelearate * [run-slow] llama --- src/transformers/cache_utils.py | 1 - src/transformers/modeling_attn_mask_utils.py | 2 +- tests/models/gemma/test_modeling_gemma.py | 5 ++- tests/models/llama/test_modeling_llama.py | 5 ++- tests/models/mistral/test_modeling_mistral.py | 2 +- .../models/nemotron/test_modeling_nemotron.py | 2 + tests/test_modeling_common.py | 43 +++++++++++++++++++ 7 files changed, 55 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 223eda10a9..0dbf0cc682 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1284,7 +1284,6 @@ class SlidingWindowCache(StaticCache): max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: - super().__init__() if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 08eeaf9765..4319c021cb 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -281,7 +281,7 @@ class AttentionMaskConverter: elif sliding_window is None or key_value_length < sliding_window: if len(attention_mask.shape) == 4: return False - elif (is_training or not is_tracing) and torch.all(attention_mask == 1): + elif not is_tracing and torch.all(attention_mask == 1): if query_length == 1 or key_value_length == query_length: # For query_length == 1, causal attention and bi-directional attention are the same. ignore_causal_mask = True diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 6422133d75..67828259f4 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -321,6 +321,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # used in `test_torch_compile` _torch_compile_test_ckpt = "google/gemma-2b" + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None + # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name @@ -808,7 +811,7 @@ class GemmaIntegrationTest(unittest.TestCase): prompts = ["Hello I am doing", "Hi today"] tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="", padding_side="right") - model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16) + model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map=torch_device, torch_dtype=torch.float16) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) # Dynamic Cache diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 6b273bce7a..3a103f3efa 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -319,6 +319,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # used in `test_torch_compile` _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf" + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None + def setUp(self): self.model_tester = LlamaModelTester(self) self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37) @@ -874,7 +877,7 @@ class LlamaIntegrationTest(unittest.TestCase): ] tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16 + "meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 88140b1a20..01dd303095 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -677,7 +677,7 @@ class MistralIntegrationTest(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) tokenizer.pad_token = tokenizer.eos_token model = MistralForCausalLM.from_pretrained( - "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16 + "mistralai/Mistral-7B-v0.1", device_map=torch_device, torch_dtype=torch.float16 ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py index 4f8f4cc77f..13adfe1e57 100644 --- a/tests/models/nemotron/test_modeling_nemotron.py +++ b/tests/models/nemotron/test_modeling_nemotron.py @@ -94,6 +94,8 @@ class NemotronModelTest(GemmaModelTest): # used in `test_torch_compile` _torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf" + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None def setUp(self): self.model_tester = NemotronModelTester(self) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 66b4e25a45..1f5b232f0d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4937,6 +4937,49 @@ class ModelTesterMixin: for i in range(n_iter): _ = model.generate(**input_ids, do_sample=False) + @slow + @require_torch_gpu + def test_torch_compile_for_training(self): + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + if not hasattr(self, "_torch_compile_train_cls"): + self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.") + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + cls = self._torch_compile_train_cls + model = cls(config).to(torch_device) + + inputs = { + "input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device), + "attention_mask": torch.tensor( + [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], + dtype=torch.int64, + device=torch_device, + ), + "position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0), + "labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device), + } + + # eager backward + set_seed(42) + loss = model(**inputs).loss + loss.backward() + + params = {name: param.grad.clone().detach().cpu() for name, param in model.named_parameters()} + model.zero_grad() + del loss + + model = torch.compile(model, fullgraph=True, mode="reduce-overhead") + # forward compilation + set_seed(42) + loss = model(**inputs).loss + # backward compilation + loss.backward() + # check grad matches + for name, param in model._orig_mod.named_parameters(): + torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4) + @slow @require_torch_gpu # Testing cuda graphs. @require_read_token