Generate: handle text conditioning with multimodal encoder-decoder models (#22748)

This commit is contained in:
Joao Gante
2023-04-13 19:51:13 +01:00
committed by GitHub
parent 90ce374d14
commit 9dfd6a4baa
6 changed files with 123 additions and 66 deletions

View File

@@ -94,8 +94,8 @@ class GenerationIntegrationTestsMixin:
# Decoder only call
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
# 29 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 32])
# 1 BOS + 29 (input length) + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 33])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
@@ -658,3 +658,31 @@ class GenerationIntegrationTestsMixin:
[token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
)
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
def test_generate_vision2text_conditioning(self):
    """Generation from pixel inputs must honour text conditioning passed under either keyword.

    The same conditioning tensor is supplied once as `decoder_input_ids` and once as `input_ids`.
    Both calls must produce identical sequences, and the conditioning token must sit at position 1
    of the generated output.
    """
    # Framework-specific handles supplied by the concrete test class.
    framework_params = self.framework_dependent_parameters
    vision2seq_cls = framework_params["AutoModelForVision2Seq"]
    make_floats = framework_params["floats_tensor"]
    make_tensor = framework_params["create_tensor_fn"]
    uses_torch = not vision2seq_cls.__name__.startswith("TF")

    image_batch = make_floats((2, 3, 30, 30))
    # this should be the 2nd output token, after the BOS token
    prompt_ids = make_tensor([[10], [10]])
    vision2seq_model = vision2seq_cls.from_pretrained(
        "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
    )
    if uses_torch:
        image_batch = image_batch.to(torch_device)
        vision2seq_model = vision2seq_model.to(torch_device)
        prompt_ids = prompt_ids.to(torch_device)

    # Conditioning can be given as decoder_input_ids (the expected decoder input) or as input_ids
    # (piped internally as decoder_input_ids when the encoder is not a model with text input).
    generated_via_decoder_kwarg = vision2seq_model.generate(
        image_batch, max_length=5, decoder_input_ids=prompt_ids
    )
    generated_via_input_kwarg = vision2seq_model.generate(image_batch, max_length=5, input_ids=prompt_ids)

    if uses_torch:
        # Bring everything to host-side numpy arrays before comparing.
        generated_via_decoder_kwarg, generated_via_input_kwarg, prompt_ids = (
            tensor.cpu().numpy()
            for tensor in (generated_via_decoder_kwarg, generated_via_input_kwarg, prompt_ids)
        )

    both_kwargs_agree = np.array_equal(generated_via_decoder_kwarg, generated_via_input_kwarg)
    self.assertTrue(both_kwargs_agree)
    prompt_is_second_token = np.array_equal(generated_via_decoder_kwarg[:, 1:2], prompt_ids)
    self.assertTrue(prompt_is_second_token)

View File

@@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
@@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
@@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
@@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
@@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# Greedy
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)