From edc1e734bfc01109b8c66881d950ebbda032a6d2 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 13 Feb 2023 16:44:27 +0100 Subject: [PATCH] Fix Blip-2 CI (#21595) * use fp16 * use fp16 * use fp16 * use fp16 --------- Co-authored-by: ydshieh --- tests/models/blip_2/test_modeling_blip_2.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 47bc6933be..40f64e971a 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -768,11 +768,13 @@ def prepare_img(): class Blip2ModelIntegrationTest(unittest.TestCase): def test_inference_opt(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") - model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device) + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16 + ).to(torch_device) # prepare image image = prepare_img() - inputs = processor(images=image, return_tensors="pt").to(torch_device) + inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16) predictions = model.generate(**inputs) generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() @@ -783,7 +785,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase): # image and context prompt = "Question: which city is this? Answer:" - inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device) + inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16) predictions = model.generate(**inputs) generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() @@ -797,11 +799,13 @@ class Blip2ModelIntegrationTest(unittest.TestCase): def test_inference_t5(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl") - model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device) + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16 + ).to(torch_device) # prepare image image = prepare_img() - inputs = processor(images=image, return_tensors="pt").to(torch_device) + inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16) predictions = model.generate(**inputs) generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() @@ -812,7 +816,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase): # image and context prompt = "Question: which city is this? Answer:" - inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device) + inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16) predictions = model.generate(**inputs) generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()