From a582cfce3cff74a6f6e995843a6f2f9085680e1f Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 30 Jan 2023 10:37:56 +0100 Subject: [PATCH] Fix `GitModelIntegrationTest.test_batched_generation` device issue (#21362) fix Co-authored-by: ydshieh --- tests/models/git/test_modeling_git.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index e399ddea57..40e435056b 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -508,9 +508,8 @@ class GitModelIntegrationTest(unittest.TestCase): # we have to prepare `input_ids` with the same batch size as `pixel_values` start_token_id = model.config.bos_token_id - generated_ids = model.generate( - pixel_values=pixel_values, input_ids=torch.tensor([[start_token_id], [start_token_id]]), max_length=50 - ) + input_ids = torch.tensor([[start_token_id], [start_token_id]], device=torch_device) + generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50) generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True) self.assertEquals(generated_captions, ["two cats sleeping on a pink blanket next to remotes."] * 2)