Fix attn mask ignore logic in training-time trace (#32613)
* fix attn mask logic for training-time trace * add test * fix * fix * fix * fix * fix * format * [run-slow] llama * avoid accelearate * [run-slow] llama
This commit is contained in:
@@ -1284,7 +1284,6 @@ class SlidingWindowCache(StaticCache):
|
|||||||
max_batch_size: Optional[int] = None,
|
max_batch_size: Optional[int] = None,
|
||||||
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
|
||||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
||||||
|
|||||||
@@ -281,7 +281,7 @@ class AttentionMaskConverter:
|
|||||||
elif sliding_window is None or key_value_length < sliding_window:
|
elif sliding_window is None or key_value_length < sliding_window:
|
||||||
if len(attention_mask.shape) == 4:
|
if len(attention_mask.shape) == 4:
|
||||||
return False
|
return False
|
||||||
elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
|
elif not is_tracing and torch.all(attention_mask == 1):
|
||||||
if query_length == 1 or key_value_length == query_length:
|
if query_length == 1 or key_value_length == query_length:
|
||||||
# For query_length == 1, causal attention and bi-directional attention are the same.
|
# For query_length == 1, causal attention and bi-directional attention are the same.
|
||||||
ignore_causal_mask = True
|
ignore_causal_mask = True
|
||||||
|
|||||||
@@ -321,6 +321,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
# used in `test_torch_compile`
|
# used in `test_torch_compile`
|
||||||
_torch_compile_test_ckpt = "google/gemma-2b"
|
_torch_compile_test_ckpt = "google/gemma-2b"
|
||||||
|
|
||||||
|
# used in `test_torch_compile_for_training`
|
||||||
|
_torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
|
||||||
|
|
||||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||||
def is_pipeline_test_to_skip(
|
def is_pipeline_test_to_skip(
|
||||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||||
@@ -808,7 +811,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
prompts = ["Hello I am doing", "Hi today"]
|
prompts = ["Hello I am doing", "Hi today"]
|
||||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
|
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
|
||||||
model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
|
model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map=torch_device, torch_dtype=torch.float16)
|
||||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||||
|
|
||||||
# Dynamic Cache
|
# Dynamic Cache
|
||||||
|
|||||||
@@ -319,6 +319,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
# used in `test_torch_compile`
|
# used in `test_torch_compile`
|
||||||
_torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
|
_torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
|
||||||
|
|
||||||
|
# used in `test_torch_compile_for_training`
|
||||||
|
_torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = LlamaModelTester(self)
|
self.model_tester = LlamaModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
|
self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
|
||||||
@@ -874,7 +877,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
|
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
"meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
|
"meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||||
|
|
||||||
|
|||||||
@@ -677,7 +677,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
model = MistralForCausalLM.from_pretrained(
|
model = MistralForCausalLM.from_pretrained(
|
||||||
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
|
"mistralai/Mistral-7B-v0.1", device_map=torch_device, torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||||
|
|
||||||
|
|||||||
@@ -94,6 +94,8 @@ class NemotronModelTest(GemmaModelTest):
|
|||||||
|
|
||||||
# used in `test_torch_compile`
|
# used in `test_torch_compile`
|
||||||
_torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
|
_torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
|
||||||
|
# used in `test_torch_compile_for_training`
|
||||||
|
_torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = NemotronModelTester(self)
|
self.model_tester = NemotronModelTester(self)
|
||||||
|
|||||||
@@ -4937,6 +4937,49 @@ class ModelTesterMixin:
|
|||||||
for i in range(n_iter):
|
for i in range(n_iter):
|
||||||
_ = model.generate(**input_ids, do_sample=False)
|
_ = model.generate(**input_ids, do_sample=False)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_torch_compile_for_training(self):
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||||
|
|
||||||
|
if not hasattr(self, "_torch_compile_train_cls"):
|
||||||
|
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
|
||||||
|
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
cls = self._torch_compile_train_cls
|
||||||
|
model = cls(config).to(torch_device)
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
"input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||||
|
"attention_mask": torch.tensor(
|
||||||
|
[[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=torch_device,
|
||||||
|
),
|
||||||
|
"position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
|
||||||
|
"labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||||
|
}
|
||||||
|
|
||||||
|
# eager backward
|
||||||
|
set_seed(42)
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
params = {name: param.grad.clone().detach().cpu() for name, param in model.named_parameters()}
|
||||||
|
model.zero_grad()
|
||||||
|
del loss
|
||||||
|
|
||||||
|
model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
|
||||||
|
# forward compilation
|
||||||
|
set_seed(42)
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
# backward compilation
|
||||||
|
loss.backward()
|
||||||
|
# check grad matches
|
||||||
|
for name, param in model._orig_mod.named_parameters():
|
||||||
|
torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu # Testing cuda graphs.
|
@require_torch_gpu # Testing cuda graphs.
|
||||||
@require_read_token
|
@require_read_token
|
||||||
|
|||||||
Reference in New Issue
Block a user