Fix Blip-2 CI (#21595)
* use fp16
* use fp16
* use fp16
* use fp16

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -768,11 +768,13 @@ def prepare_img():
|
||||
class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_opt(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# prepare image
|
||||
image = prepare_img()
|
||||
inputs = processor(images=image, return_tensors="pt").to(torch_device)
|
||||
inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
@@ -783,7 +785,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# image and context
|
||||
prompt = "Question: which city is this? Answer:"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
@@ -797,11 +799,13 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_inference_t5(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
||||
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# prepare image
|
||||
image = prepare_img()
|
||||
inputs = processor(images=image, return_tensors="pt").to(torch_device)
|
||||
inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
@@ -812,7 +816,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# image and context
|
||||
prompt = "Question: which city is this? Answer:"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
Reference in New Issue
Block a user