@@ -992,7 +992,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# prepare image
|
# prepare image
|
||||||
image = prepare_img()
|
image = prepare_img()
|
||||||
inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16)
|
inputs = processor(images=image, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)
|
||||||
|
|
||||||
predictions = model.generate(**inputs)
|
predictions = model.generate(**inputs)
|
||||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||||
@@ -1003,7 +1003,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# image and context
|
# image and context
|
||||||
prompt = "Question: which city is this? Answer:"
|
prompt = "Question: which city is this? Answer:"
|
||||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)
|
inputs = processor(images=image, text=prompt, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)
|
||||||
|
|
||||||
predictions = model.generate(**inputs)
|
predictions = model.generate(**inputs)
|
||||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||||
|
|||||||
@@ -776,7 +776,7 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||||
inputs = tokenizer("Hello, my name is", return_tensors="pt")
|
inputs = tokenizer("Hello, my name is", return_tensors="pt")
|
||||||
output = model.generate(inputs["input_ids"].to(0))
|
output = model.generate(inputs["input_ids"].to(f"{torch_device}:0"))
|
||||||
|
|
||||||
text_output = tokenizer.decode(output[0].tolist())
|
text_output = tokenizer.decode(output[0].tolist())
|
||||||
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
|
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
|
||||||
|
|||||||
Reference in New Issue
Block a user