Generate: handle text conditioning with multimodal encoder-decoder models (#22748)

This commit is contained in:
Joao Gante
2023-04-13 19:51:13 +01:00
committed by GitHub
parent 90ce374d14
commit 9dfd6a4baa
6 changed files with 123 additions and 66 deletions

View File

@@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-input_ids = bart_model._prepare_decoder_input_ids_for_generation(
-input_ids.shape[0],
+input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+batch_size=input_ids.shape[0],
+model_input_name=bart_model.main_input_name,
+model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
@@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-input_ids = bart_model._prepare_decoder_input_ids_for_generation(
-input_ids.shape[0],
+input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+batch_size=input_ids.shape[0],
+model_input_name=bart_model.main_input_name,
+model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
@@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-input_ids = bart_model._prepare_decoder_input_ids_for_generation(
-input_ids.shape[0],
+input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+batch_size=input_ids.shape[0],
+model_input_name=bart_model.main_input_name,
+model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
@@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-input_ids = bart_model._prepare_decoder_input_ids_for_generation(
-input_ids.shape[0],
+input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+batch_size=input_ids.shape[0],
+model_input_name=bart_model.main_input_name,
+model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
@@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# Greedy
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-input_ids = bart_model._prepare_decoder_input_ids_for_generation(
-input_ids.shape[0],
+input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+batch_size=input_ids.shape[0],
+model_input_name=bart_model.main_input_name,
+model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)