Fix llava half precision and autocast issues (#29721)
* Ensure input_embeds and image_features are the same dtype in autocast * Fix nans in half precision llava-next and fix autocasting behavior. * Fix styling issues. * fix randn newline instantiation * fix broken slow llava test * Fix llava next init. * fix styling issues * [run-slow]llava,llava_next * fix styling issues
This commit is contained in:
@@ -27,11 +27,21 @@ from transformers import (
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -157,6 +167,39 @@ class LlavaNextVisionText2TextModelTester:
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_llava_next_model_fp16_forward(
|
||||
self, config, input_ids, pixel_values, attention_mask, image_sizes
|
||||
):
|
||||
model = LlavaNextForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
image_sizes=image_sizes,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
def create_and_check_llava_next_model_fp16_autocast_forward(
|
||||
self, config, input_ids, pixel_values, attention_mask, image_sizes
|
||||
):
|
||||
config.torch_dtype = torch.float16
|
||||
model = LlavaNextForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
image_sizes=image_sizes,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
@@ -239,14 +282,20 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
inputs = self.processor(self.prompt, self.image, return_tensors="pt")
|
||||
|
||||
# verify inputs against original implementation
|
||||
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset")
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/test-image",
|
||||
filename="llava_1_6_input_ids.pt",
|
||||
repo_type="dataset",
|
||||
)
|
||||
original_input_ids = torch.load(filepath, map_location="cpu")
|
||||
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
|
||||
original_input_ids[original_input_ids == -200] = model.config.image_token_index
|
||||
assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
|
||||
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/test-image", filename="llava_1_6_pixel_values.pt", repo_type="dataset"
|
||||
repo_id="nielsr/test-image",
|
||||
filename="llava_1_6_pixel_values.pt",
|
||||
repo_type="dataset",
|
||||
)
|
||||
original_pixel_values = torch.load(filepath, map_location="cpu")
|
||||
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
|
||||
@@ -257,7 +306,11 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
output = model(**inputs)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]],
|
||||
[
|
||||
[-4.7695, -4.5664, -0.2786],
|
||||
[-10.6250, -10.8906, -2.5254],
|
||||
[-6.7383, -7.2461, -0.6787],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=torch_device,
|
||||
)
|
||||
@@ -282,7 +335,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
cats_image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = self.processor(
|
||||
[self.prompt, self.prompt], images=[self.image, cats_image], return_tensors="pt", padding=True
|
||||
[self.prompt, self.prompt],
|
||||
images=[self.image, cats_image],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
# make sure image_sizes are the same
|
||||
@@ -292,7 +348,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
|
||||
Reference in New Issue
Block a user