[Generate] Make generate multi-modal (#14784)
* finish refactor
* refactor
* add tests
* add more tests
* up
* finish tests
* finish
* up
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* improve docstring
* fix docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
48463ebb33
commit
b18d8534ea
@@ -20,6 +20,8 @@ import unittest
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_common import floats_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -29,6 +31,9 @@ if is_torch_available():
|
||||
BartTokenizer,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
Speech2TextForConditionalGeneration,
|
||||
SpeechEncoderDecoderModel,
|
||||
VisionEncoderDecoderModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.generation_beam_search import BeamSearchScorer
|
||||
@@ -1724,3 +1729,74 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# cannot generate from `inputs_embeds` for decoder only
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(inputs_embeds=inputs_embeds)
|
||||
|
||||
def test_generate_input_ids_as_kwarg(self):
    """Check that GPT-2 `generate` treats positional and keyword `input_ids` alike.

    The tiny random GPT-2 checkpoint is loaded with `max_length=15`; both call
    styles must return identical token ids with shape `(1, 15)`.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    prompt = "I need input_ids to generate"

    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    gpt2 = GPT2LMHeadModel.from_pretrained(checkpoint, max_length=15)
    gpt2 = gpt2.to(torch_device)

    encoded = gpt2_tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)

    # Keyword call first, then positional call, as two independent generations.
    from_keyword = gpt2.generate(input_ids=input_ids).cpu()
    from_positional = gpt2.generate(input_ids).cpu()

    self.assertListEqual(from_positional.tolist(), from_keyword.tolist())
    self.assertEqual(from_positional.shape, (1, 15))
|
||||
|
||||
def test_generate_input_ids_as_encoder_kwarg(self):
    """Check that BART `generate` treats positional and keyword `input_ids` alike.

    The tiny random BART checkpoint is loaded with `max_length=5`; both call
    styles must return identical token ids with shape `(1, 5)`.
    """
    checkpoint = "hf-internal-testing/tiny-random-bart"
    text = "Justin Timberlake and Jessica Biel, welcome to parenthood."

    bart_tokenizer = BartTokenizer.from_pretrained(checkpoint)
    bart = BartForConditionalGeneration.from_pretrained(checkpoint, max_length=5)
    bart = bart.to(torch_device)
    # NOTE(review): presumably cleared so generation cannot stop early on EOS
    # and the output length below is deterministic -- confirm.
    bart.config.eos_token_id = None

    encoded = bart_tokenizer(text, return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)

    # Keyword call first, then positional call, as two independent generations.
    from_keyword = bart.generate(input_ids=input_ids).cpu()
    from_positional = bart.generate(input_ids).cpu()

    self.assertListEqual(from_positional.tolist(), from_keyword.tolist())
    self.assertEqual(from_positional.shape, (1, 5))
|
||||
|
||||
def test_generate_inputs_and_encoder_kwargs(self):
    """Expect `ValueError` when inputs are given positionally AND as `input_ids`."""
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    prompt = "I need input_ids to generate"

    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    gpt2 = GPT2LMHeadModel.from_pretrained(checkpoint, max_length=10)
    gpt2 = gpt2.to(torch_device)

    encoded = gpt2_tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)

    # The same tensor is supplied twice: once positionally, once by keyword.
    self.assertRaises(ValueError, gpt2.generate, input_ids, input_ids=input_ids)
|
||||
|
||||
def test_generate_too_many_encoder_kwargs(self):
    """Expect `ValueError` when both `input_ids` and `input_values` are passed."""
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    prompt = "I need input_ids to generate"

    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    gpt2 = GPT2LMHeadModel.from_pretrained(checkpoint, max_length=10)
    gpt2 = gpt2.to(torch_device)

    encoded = gpt2_tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)

    # Two different input keywords carrying the same tensor must be rejected.
    conflicting_inputs = {"input_ids": input_ids, "input_values": input_ids}
    with self.assertRaises(ValueError):
        gpt2.generate(**conflicting_inputs)
|
||||
|
||||
def test_generate_input_values_as_encoder_kwarg(self):
    """Check that `SpeechEncoderDecoderModel.generate` accepts `input_values` either way.

    Passing the float tensor positionally or as `input_values=` with
    `max_length=5` must give identical sequences of shape `(2, 5)`.
    """
    batch_size, max_len = 2, 5
    input_values = floats_tensor((batch_size, 250))

    speech_model = SpeechEncoderDecoderModel.from_pretrained(
        "hf-internal-testing/tiny-random-speech-encoder-decoder"
    ).to(torch_device)

    # Keyword call first, then positional call, as two independent generations.
    from_keyword = speech_model.generate(input_values=input_values, max_length=max_len).cpu()
    from_positional = speech_model.generate(input_values, max_length=max_len).cpu()

    self.assertListEqual(from_positional.tolist(), from_keyword.tolist())
    self.assertEqual(from_positional.shape, (batch_size, max_len))
|
||||
|
||||
def test_generate_input_features_as_encoder_kwarg(self):
    """Check that Speech2Text `generate` accepts `input_features` either way.

    Passing the float tensor positionally or as `input_features=` with
    `max_length=5` must give identical sequences of shape `(3, 5)`.
    """
    batch_size, max_len = 3, 5
    input_features = floats_tensor((batch_size, 20, 24))

    s2t_model = Speech2TextForConditionalGeneration.from_pretrained(
        "hf-internal-testing/tiny-random-speech_to_text"
    ).to(torch_device)

    # Keyword call first, then positional call, as two independent generations.
    from_keyword = s2t_model.generate(input_features=input_features, max_length=max_len).cpu()
    from_positional = s2t_model.generate(input_features, max_length=max_len).cpu()

    self.assertListEqual(from_positional.tolist(), from_keyword.tolist())
    self.assertEqual(from_positional.shape, (batch_size, max_len))
|
||||
|
||||
def test_generate_pixel_values_as_encoder_kwarg(self):
    """Check that `VisionEncoderDecoderModel.generate` accepts `pixel_values` either way.

    Passing the float tensor positionally or as `pixel_values=` with
    `max_length=5` must give identical sequences of shape `(2, 5)`.
    """
    batch_size, max_len = 2, 5
    pixel_values = floats_tensor((batch_size, 3, 30, 30))

    vision_model = VisionEncoderDecoderModel.from_pretrained(
        "hf-internal-testing/tiny-random-vision-encoder-decoder"
    ).to(torch_device)

    # Keyword call first, then positional call, as two independent generations.
    from_keyword = vision_model.generate(pixel_values=pixel_values, max_length=max_len).cpu()
    from_positional = vision_model.generate(pixel_values, max_length=max_len).cpu()

    self.assertListEqual(from_positional.tolist(), from_keyword.tolist())
    self.assertEqual(from_positional.shape, (batch_size, max_len))
|
||||
|
||||
Reference in New Issue
Block a user