[BLIP-2] Improve conversion script (#24854)
* Improve conversion script
* Add int8 code example
* Update tip
* Fix code
* Fix code snippet
* Add nucleus sampling
* More improvements
* Address comments
* Address comments
This commit is contained in:
@@ -90,7 +90,7 @@ class Blip2VisionConfig(PretrainedConfig):
|
|||||||
image_size=224,
|
image_size=224,
|
||||||
patch_size=14,
|
patch_size=14,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
layer_norm_eps=0.00001,
|
layer_norm_eps=1e-6,
|
||||||
attention_dropout=0.0,
|
attention_dropout=0.0,
|
||||||
initializer_range=1e-10,
|
initializer_range=1e-10,
|
||||||
qkv_bias=True,
|
qkv_bias=True,
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
# pip3 install salesforce-lavis
|
# pip3 install salesforce-lavis
|
||||||
# I'm actually installing a slightly modified version: pip3 install git+https://github.com/nielsrogge/LAVIS.git@fix_lavis
|
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
|
||||||
|
# to make sure we can compare both original and HF implementation in float32
|
||||||
from lavis.models import load_model_and_preprocess
|
from lavis.models import load_model_and_preprocess
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@@ -37,6 +38,7 @@ from transformers import (
|
|||||||
BlipImageProcessor,
|
BlipImageProcessor,
|
||||||
OPTConfig,
|
OPTConfig,
|
||||||
T5Config,
|
T5Config,
|
||||||
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||||
|
|
||||||
@@ -145,11 +147,16 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
|||||||
|
|
||||||
name, type = model_name_to_original[model_name]
|
name, type = model_name_to_original[model_name]
|
||||||
|
|
||||||
|
# note: this script is tested on 2 GPUs, as models are compared in float32,
|
||||||
|
# which requires quite some memory. Hence loading each model on its own
|
||||||
|
# device is the easiest way to compare them
|
||||||
|
hf_model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
lavis_device = "cuda:1" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
# load original model
|
# load original model
|
||||||
print("Loading original model...")
|
print("Loading original model...")
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||||
name=name, model_type=type, is_eval=True, device=device
|
name=name, model_type=type, is_eval=True, device=lavis_device
|
||||||
)
|
)
|
||||||
original_model.eval()
|
original_model.eval()
|
||||||
print("Done!")
|
print("Done!")
|
||||||
@@ -185,61 +192,53 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
|||||||
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
||||||
|
|
||||||
image = load_demo_image()
|
image = load_demo_image()
|
||||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(device)
|
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||||
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(device)
|
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
|
||||||
|
|
||||||
# create processor
|
# create processor
|
||||||
image_processor = BlipImageProcessor(
|
image_processor = BlipImageProcessor(
|
||||||
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
||||||
)
|
)
|
||||||
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
||||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
|
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
|
||||||
|
|
||||||
# make sure processor creates exact same pixel values
|
# make sure processor creates exact same pixel values
|
||||||
assert torch.allclose(pixel_values, original_pixel_values)
|
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
|
||||||
|
|
||||||
original_model.to(device)
|
original_model.to(lavis_device)
|
||||||
hf_model.to(device)
|
hf_model.to(hf_model_device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "opt" in model_name:
|
if "opt" in model_name:
|
||||||
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
||||||
logits = hf_model(original_pixel_values, input_ids).logits
|
logits = hf_model(pixel_values, input_ids).logits
|
||||||
else:
|
else:
|
||||||
original_logits = original_model(
|
original_logits = original_model(
|
||||||
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
|
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
|
||||||
).logits
|
).logits
|
||||||
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
|
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
|
||||||
logits = hf_model(original_pixel_values, input_ids, labels=labels).logits
|
logits = hf_model(pixel_values, input_ids, labels=labels).logits
|
||||||
|
|
||||||
assert original_logits.shape == logits.shape
|
assert original_logits.shape == logits.shape
|
||||||
print("First values of original logits:", original_logits[0, :3, :3])
|
print("First values of original logits:", original_logits[0, :3, :3])
|
||||||
print("First values of HF logits:", logits[0, :3, :3])
|
print("First values of HF logits:", logits[0, :3, :3])
|
||||||
|
|
||||||
# assert values
|
# assert values
|
||||||
if model_name == "blip2-flan-t5-xl":
|
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
|
||||||
expected_slice_logits = torch.tensor(
|
|
||||||
[[-41.5850, -4.4440, -8.9922], [-47.4322, -5.9143, -1.7340]], device=device
|
|
||||||
)
|
|
||||||
assert torch.allclose(logits[0, :3, :3], expected_slice_logits, atol=1e-4)
|
|
||||||
elif model_name == "blip2-flan-t5-xl-coco":
|
|
||||||
expected_slice_logits = torch.tensor(
|
|
||||||
[[-57.0109, -9.8967, -12.6280], [-68.6578, -12.7191, -10.5065]], device=device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# cast to same type
|
|
||||||
target_dtype = logits.dtype
|
|
||||||
assert torch.allclose(original_logits.to(target_dtype), logits, atol=1e-2)
|
|
||||||
print("Looks ok!")
|
print("Looks ok!")
|
||||||
|
|
||||||
print("Generating a caption...")
|
print("Generating a caption...")
|
||||||
prompt = ""
|
prompt = "Question: what object is in this image? Answer:"
|
||||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
|
||||||
|
|
||||||
original_outputs = original_model.generate({"image": original_pixel_values})
|
set_seed(42)
|
||||||
|
|
||||||
|
original_outputs = original_model.generate(
|
||||||
|
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True
|
||||||
|
)
|
||||||
outputs = hf_model.generate(
|
outputs = hf_model.generate(
|
||||||
original_pixel_values,
|
pixel_values,
|
||||||
input_ids,
|
input_ids,
|
||||||
do_sample=False,
|
do_sample=True,
|
||||||
num_beams=5,
|
num_beams=5,
|
||||||
max_length=30,
|
max_length=30,
|
||||||
min_length=1,
|
min_length=1,
|
||||||
@@ -248,10 +247,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
|||||||
length_penalty=1.0,
|
length_penalty=1.0,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
)
|
)
|
||||||
print("Original generation:", original_outputs)
|
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
|
||||||
prompt_length = input_ids.shape[1]
|
|
||||||
output_text = processor.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)
|
|
||||||
output_text = [text.strip() for text in output_text]
|
output_text = [text.strip() for text in output_text]
|
||||||
|
print("Original generation:", original_outputs)
|
||||||
print("HF generation:", output_text)
|
print("HF generation:", output_text)
|
||||||
|
|
||||||
if pytorch_dump_folder_path is not None:
|
if pytorch_dump_folder_path is not None:
|
||||||
|
|||||||
@@ -1556,6 +1556,12 @@ class Blip2Model(Blip2PreTrainedModel):
|
|||||||
|
|
||||||
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
||||||
the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
|
the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Note that Flan-T5 checkpoints cannot be cast to float16. They are pre-trained using bfloat16.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
""",
|
""",
|
||||||
BLIP_2_START_DOCSTRING,
|
BLIP_2_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
@@ -1687,15 +1693,40 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
|
|||||||
|
|
||||||
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||||
>>> model = Blip2ForConditionalGeneration.from_pretrained(
|
>>> model = Blip2ForConditionalGeneration.from_pretrained(
|
||||||
... "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
|
... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
|
||||||
... )
|
... ) # doctest: +IGNORE_RESULT
|
||||||
>>> model.to(device) # doctest: +IGNORE_RESULT
|
|
||||||
|
|
||||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
>>> prompt = "Question: how many cats are there? Answer:"
|
>>> prompt = "Question: how many cats are there? Answer:"
|
||||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)
|
||||||
|
|
||||||
|
>>> generated_ids = model.generate(**inputs)
|
||||||
|
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||||
|
>>> print(generated_text)
|
||||||
|
two
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that int8 inference is also supported through [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
|
||||||
|
This greatly reduces the amount of memory used by the model while largely preserving the quality of its outputs.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
||||||
|
>>> model = Blip2ForConditionalGeneration.from_pretrained(
|
||||||
|
... "Salesforce/blip2-flan-t5-xl", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.bfloat16
|
||||||
|
... ) # doctest: +IGNORE_RESULT
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> prompt = "Question: how many cats are there? Answer:"
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
>>> generated_ids = model.generate(**inputs)
|
>>> generated_ids = model.generate(**inputs)
|
||||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ src/transformers/models/blip/image_processing_blip.py
|
|||||||
src/transformers/models/blip/modeling_blip.py
|
src/transformers/models/blip/modeling_blip.py
|
||||||
src/transformers/models/blip/modeling_tf_blip.py
|
src/transformers/models/blip/modeling_tf_blip.py
|
||||||
src/transformers/models/blip/processing_blip.py
|
src/transformers/models/blip/processing_blip.py
|
||||||
|
src/transformers/models/blip_2/modeling_blip_2.py
|
||||||
src/transformers/models/blip_2/processing_blip_2.py
|
src/transformers/models/blip_2/processing_blip_2.py
|
||||||
src/transformers/models/bloom/configuration_bloom.py
|
src/transformers/models/bloom/configuration_bloom.py
|
||||||
src/transformers/models/bloom/tokenization_bloom_fast.py
|
src/transformers/models/bloom/tokenization_bloom_fast.py
|
||||||
|
|||||||
Reference in New Issue
Block a user