Fix llava half precision and autocast issues (#29721)

* Ensure input_embeds and image_features are the same dtype in autocast

* Fix nans in half precision llava-next and fix autocasting behavior.

* Fix styling issues.

* fix randn newline instantiation

* fix broken slow llava test

* Fix llava next init.

* fix styling issues

* [run-slow]llava,llava_next

* fix styling issues
This commit is contained in:
Fraser Mince
2024-05-01 11:49:44 -05:00
committed by GitHub
parent d57ffb487f
commit 5090ea3f68
4 changed files with 102 additions and 16 deletions

View File

@@ -27,11 +27,21 @@ from transformers import (
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
)
if is_torch_available():
@@ -157,6 +167,39 @@ class LlavaNextVisionText2TextModelTester:
}
return config, inputs_dict
def create_and_check_llava_next_model_fp16_forward(
self, config, input_ids, pixel_values, attention_mask, image_sizes
):
model = LlavaNextForConditionalGeneration(config=config)
model.to(torch_device)
model.half()
model.eval()
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
image_sizes=image_sizes,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())
def create_and_check_llava_next_model_fp16_autocast_forward(
self, config, input_ids, pixel_values, attention_mask, image_sizes
):
config.torch_dtype = torch.float16
model = LlavaNextForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.float16):
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
image_sizes=image_sizes,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())
@require_torch
class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -239,14 +282,20 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
inputs = self.processor(self.prompt, self.image, return_tensors="pt")
# verify inputs against original implementation
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset")
filepath = hf_hub_download(
repo_id="nielsr/test-image",
filename="llava_1_6_input_ids.pt",
repo_type="dataset",
)
original_input_ids = torch.load(filepath, map_location="cpu")
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
original_input_ids[original_input_ids == -200] = model.config.image_token_index
assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
filepath = hf_hub_download(
repo_id="nielsr/test-image", filename="llava_1_6_pixel_values.pt", repo_type="dataset"
repo_id="nielsr/test-image",
filename="llava_1_6_pixel_values.pt",
repo_type="dataset",
)
original_pixel_values = torch.load(filepath, map_location="cpu")
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
@@ -257,7 +306,11 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model(**inputs)
expected_slice = torch.tensor(
[[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]],
[
[-4.7695, -4.5664, -0.2786],
[-10.6250, -10.8906, -2.5254],
[-6.7383, -7.2461, -0.6787],
],
dtype=torch.float32,
device=torch_device,
)
@@ -282,7 +335,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
cats_image = Image.open(requests.get(url, stream=True).raw)
inputs = self.processor(
[self.prompt, self.prompt], images=[self.image, cats_image], return_tensors="pt", padding=True
[self.prompt, self.prompt],
images=[self.image, cats_image],
return_tensors="pt",
padding=True,
).to(torch_device)
# make sure image_sizes are the same
@@ -292,7 +348,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes