LLaVaNeXT: pad on right if training (#32134)

* pad on right if training

* docs

* add tests
This commit is contained in:
Raushan Turganbay
2024-07-23 10:23:55 +05:00
committed by GitHub
parent 251a2409c6
commit 3aefb4ec7f
6 changed files with 90 additions and 10 deletions

View File

@@ -124,7 +124,7 @@ class LlavaNextVideoVisionText2TextModelTester:
self.batch_size = 3
self.num_channels = 3
self.image_size = 30
self.encoder_seq_length = 468
self.encoder_seq_length = 469
self.image_grid_pinpoints = [[32, 32]]
def get_config(self):
@@ -166,9 +166,7 @@ class LlavaNextVideoVisionText2TextModelTester:
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
# make attention mask left-padded to avoid issues with "model has no attribute padding_side"
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 images and videos let's make sure we pass in 3 special tokens
input_ids[:, 1] = config.image_token_index
input_ids[:, 2] = config.video_token_index
@@ -453,3 +451,39 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
self.processor.decode(output_batched[0], skip_special_tokens=True),
self.processor.decode(output_single[0], skip_special_tokens=True),
)
@slow
@require_bitsandbytes
def test_padding_side_when_merging_inputs(self):
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
"llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
)
inputs_batched = self.processor(
[self.prompt_video, self.prompt_image],
images=[self.image],
videos=[self.video],
return_tensors="pt",
padding=True,
).to(torch_device)
# model is in eval mode by default so we should get pad on the left side
# we can check the first hidden-states (aka inputs embeds)
# the first element was lo-res image and we expect the first 1482 tokens to be all pads
output_eval = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())
# otherwise padding is on the right side, so it's last 1482 tokens
self.processor.padding_side = "right"
inputs_batched = self.processor(
[self.prompt_video, self.prompt_image],
images=[self.image],
videos=[self.video],
return_tensors="pt",
padding=True,
).to(torch_device)
model.train()
with torch.no_grad():
output_train = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())