@@ -992,7 +992,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# prepare image
|
||||
image = prepare_img()
|
||||
inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16)
|
||||
inputs = processor(images=image, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
@@ -1003,7 +1003,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# image and context
|
||||
prompt = "Question: which city is this? Answer:"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
Reference in New Issue
Block a user