[tests] make 2 tests device-agnostic (#30008)

add torch device
This commit is contained in:
Fanli Lin
2024-04-10 20:46:39 +08:00
committed by GitHub
parent bb76f81e40
commit 185463784e
2 changed files with 3 additions and 3 deletions

View File

@@ -776,7 +776,7 @@ class ModelUtilsTest(TestCasePlus):
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
output = model.generate(inputs["input_ids"].to(0))
output = model.generate(inputs["input_ids"].to(f"{torch_device}:0"))
text_output = tokenizer.decode(output[0].tolist())
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")