Fix attn mask ignore logic in training-time trace (#32613)

* fix attn mask logic for training-time trace

* add test

* fix

* fix

* fix

* fix

* fix

* format

* [run-slow] llama

* avoid accelearate

* [run-slow] llama
This commit is contained in:
Longjie Zheng
2024-10-04 13:00:45 -04:00
committed by GitHub
parent 614660fdb9
commit 0d1692a49b
7 changed files with 55 additions and 5 deletions

View File

@@ -94,6 +94,8 @@ class NemotronModelTest(GemmaModelTest):
# used in `test_torch_compile`
_torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None
def setUp(self):
self.model_tester = NemotronModelTester(self)