Generate: handle text conditioning with multimodal encoder-decoder models (#22748)
This commit is contained in:
@@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
max_length = 20
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
max_length = 20
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
input_ids = input_ids.expand(6, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
# Greedy
|
||||
input_ids = input_ids.expand(6, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids.shape[0],
|
||||
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=input_ids.shape[0],
|
||||
model_input_name=bart_model.main_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user