[tests] make cuda-only tests device-agnostic (#35607)
* intial commit * remove unrelated files * further remove * Update test_trainer.py * fix style
This commit is contained in:
@@ -1831,7 +1831,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
_ = trainer.train()
|
||||
|
||||
@require_grokadamw
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_grokadamw(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
@@ -1852,7 +1852,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
_ = trainer.train()
|
||||
|
||||
@require_schedulefree
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_schedulefree_adam(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
|
||||
Reference in New Issue
Block a user