Generate: handle text conditioning with multimodal encoder-decoder models (#22748)
This commit is contained in:
@@ -94,8 +94,8 @@ class GenerationIntegrationTestsMixin:
|
||||
|
||||
# Decoder only call
|
||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||
# 29 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 32])
|
||||
# 1 BOS + 29 (input length) + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 33])
|
||||
|
||||
# Encoder decoder call > 20
|
||||
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
@@ -658,3 +658,31 @@ class GenerationIntegrationTestsMixin:
|
||||
[token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
|
||||
)
|
||||
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
|
||||
|
||||
def test_generate_vision2text_conditioning(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
|
||||
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
|
||||
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
pixel_values = floats_tensor((2, 3, 30, 30))
|
||||
conditioning_input = create_tensor_fn([[10], [10]]) # this should be the 2nd output token, after the BOS token
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
|
||||
if is_pt:
|
||||
pixel_values = pixel_values.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
conditioning_input = conditioning_input.to(torch_device)
|
||||
|
||||
# we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as
|
||||
# decoder_input_ids, if the encoder is not a model with text input)
|
||||
output_sequences_decoder_input_ids = model.generate(
|
||||
pixel_values, max_length=5, decoder_input_ids=conditioning_input
|
||||
)
|
||||
output_sequences_input_ids = model.generate(pixel_values, max_length=5, input_ids=conditioning_input)
|
||||
if is_pt:
|
||||
output_sequences_decoder_input_ids = output_sequences_decoder_input_ids.cpu().numpy()
|
||||
output_sequences_input_ids = output_sequences_input_ids.cpu().numpy()
|
||||
conditioning_input = conditioning_input.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
||||
|
||||
Reference in New Issue
Block a user