Fix llava half precision and autocast issues (#29721)
* Ensure input_embeds and image_features are the same dtype in autocast * Fix nans in half precision llava-next and fix autocasting behavior. * Fix styling issues. * fix randn newline instantiation * fix broken slow llava test * Fix llava next init. * fix styling issues * [run-slow]llava,llava_next * fix styling issues
This commit is contained in:
@@ -438,6 +438,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
image_features = self.multi_modal_projector(selected_image_feature)
|
image_features = self.multi_modal_projector(selected_image_feature)
|
||||||
|
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch Llava-NeXT model."""
|
""" PyTorch Llava-NeXT model."""
|
||||||
|
|
||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -306,8 +307,8 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||||
|
|
||||||
self.multi_modal_projector = LlavaNextMultiModalProjector(config)
|
self.multi_modal_projector = LlavaNextMultiModalProjector(config)
|
||||||
|
embed_std = 1 / math.sqrt(config.text_config.hidden_size)
|
||||||
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=self.dtype))
|
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
|
||||||
|
|
||||||
self.vocab_size = config.text_config.vocab_size
|
self.vocab_size = config.text_config.vocab_size
|
||||||
self.language_model = AutoModelForCausalLM.from_config(
|
self.language_model = AutoModelForCausalLM.from_config(
|
||||||
@@ -543,7 +544,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
image_feature = torch.cat(
|
image_feature = torch.cat(
|
||||||
(
|
(
|
||||||
image_feature,
|
image_feature,
|
||||||
self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
|
self.image_newline[:, None, None]
|
||||||
|
.expand(*image_feature.shape[:-1], 1)
|
||||||
|
.to(image_feature.dtype),
|
||||||
),
|
),
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
@@ -554,6 +557,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
|
image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
|
||||||
new_image_features.append(image_feature)
|
new_image_features.append(image_feature)
|
||||||
image_features = torch.stack(new_image_features, dim=0)
|
image_features = torch.stack(new_image_features, dim=0)
|
||||||
|
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||||
|
|
||||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||||
|
|||||||
@@ -157,6 +157,19 @@ class LlavaVisionText2TextModelTester:
|
|||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def create_and_check_llava_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||||
|
model = LlavaForConditionalGeneration(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||||
|
logits = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
pixel_values=pixel_values.to(torch.bfloat16),
|
||||||
|
return_dict=True,
|
||||||
|
)["logits"]
|
||||||
|
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
|
class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
@@ -225,7 +238,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
def test_small_model_integration_test_llama(self):
|
def test_small_model_integration_test_llama_single(self):
|
||||||
# Let' s make sure we test the preprocessing to replace what is used
|
# Let' s make sure we test the preprocessing to replace what is used
|
||||||
model_id = "llava-hf/llava-1.5-7b-hf"
|
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||||
|
|
||||||
@@ -238,7 +251,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
|
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
|
||||||
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
|
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Lastly, be respectful of the environment and other visitors, as the pier is a shared space where people can enjoy the view, relax, or engage in recreational activities." # fmt: skip
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
processor.decode(output[0], skip_special_tokens=True),
|
processor.decode(output[0], skip_special_tokens=True),
|
||||||
@@ -267,7 +280,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
|
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
|
||||||
|
|
||||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
self.assertEqual(
|
||||||
|
processor.batch_decode(output, skip_special_tokens=True),
|
||||||
|
EXPECTED_DECODED_TEXT,
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@@ -287,7 +303,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip
|
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip
|
||||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
self.assertEqual(
|
||||||
|
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||||
|
EXPECTED_DECODED_TEXT,
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@@ -314,7 +333,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
||||||
|
|
||||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
self.assertEqual(
|
||||||
|
processor.batch_decode(output, skip_special_tokens=True),
|
||||||
|
EXPECTED_DECODED_TEXT,
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -342,7 +364,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
EXPECTED_OUTPUT = [
|
EXPECTED_OUTPUT = [
|
||||||
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
|
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog in one and a ll",
|
||||||
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
|
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
|
||||||
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
|
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -27,11 +27,21 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
|
require_bitsandbytes,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
from ...test_modeling_common import (
|
||||||
|
ModelTesterMixin,
|
||||||
|
_config_zero_init,
|
||||||
|
floats_tensor,
|
||||||
|
ids_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -157,6 +167,39 @@ class LlavaNextVisionText2TextModelTester:
|
|||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def create_and_check_llava_next_model_fp16_forward(
|
||||||
|
self, config, input_ids, pixel_values, attention_mask, image_sizes
|
||||||
|
):
|
||||||
|
model = LlavaNextForConditionalGeneration(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.half()
|
||||||
|
model.eval()
|
||||||
|
logits = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
pixel_values=pixel_values.to(torch.bfloat16),
|
||||||
|
return_dict=True,
|
||||||
|
)["logits"]
|
||||||
|
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||||
|
|
||||||
|
def create_and_check_llava_next_model_fp16_autocast_forward(
|
||||||
|
self, config, input_ids, pixel_values, attention_mask, image_sizes
|
||||||
|
):
|
||||||
|
config.torch_dtype = torch.float16
|
||||||
|
model = LlavaNextForConditionalGeneration(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||||
|
logits = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
pixel_values=pixel_values.to(torch.bfloat16),
|
||||||
|
return_dict=True,
|
||||||
|
)["logits"]
|
||||||
|
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
@@ -239,14 +282,20 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
inputs = self.processor(self.prompt, self.image, return_tensors="pt")
|
inputs = self.processor(self.prompt, self.image, return_tensors="pt")
|
||||||
|
|
||||||
# verify inputs against original implementation
|
# verify inputs against original implementation
|
||||||
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset")
|
filepath = hf_hub_download(
|
||||||
|
repo_id="nielsr/test-image",
|
||||||
|
filename="llava_1_6_input_ids.pt",
|
||||||
|
repo_type="dataset",
|
||||||
|
)
|
||||||
original_input_ids = torch.load(filepath, map_location="cpu")
|
original_input_ids = torch.load(filepath, map_location="cpu")
|
||||||
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
|
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
|
||||||
original_input_ids[original_input_ids == -200] = model.config.image_token_index
|
original_input_ids[original_input_ids == -200] = model.config.image_token_index
|
||||||
assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
|
assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
|
||||||
|
|
||||||
filepath = hf_hub_download(
|
filepath = hf_hub_download(
|
||||||
repo_id="nielsr/test-image", filename="llava_1_6_pixel_values.pt", repo_type="dataset"
|
repo_id="nielsr/test-image",
|
||||||
|
filename="llava_1_6_pixel_values.pt",
|
||||||
|
repo_type="dataset",
|
||||||
)
|
)
|
||||||
original_pixel_values = torch.load(filepath, map_location="cpu")
|
original_pixel_values = torch.load(filepath, map_location="cpu")
|
||||||
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
|
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
|
||||||
@@ -257,7 +306,11 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
output = model(**inputs)
|
output = model(**inputs)
|
||||||
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]],
|
[
|
||||||
|
[-4.7695, -4.5664, -0.2786],
|
||||||
|
[-10.6250, -10.8906, -2.5254],
|
||||||
|
[-6.7383, -7.2461, -0.6787],
|
||||||
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
@@ -282,7 +335,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
cats_image = Image.open(requests.get(url, stream=True).raw)
|
cats_image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
inputs = self.processor(
|
inputs = self.processor(
|
||||||
[self.prompt, self.prompt], images=[self.image, cats_image], return_tensors="pt", padding=True
|
[self.prompt, self.prompt],
|
||||||
|
images=[self.image, cats_image],
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
|
|
||||||
# make sure image_sizes are the same
|
# make sure image_sizes are the same
|
||||||
@@ -292,7 +348,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
|
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
|
||||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
self.assertEqual(
|
||||||
|
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||||
|
EXPECTED_DECODED_TEXT,
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
|
|||||||
Reference in New Issue
Block a user