[ImageGPT] Deprecate pixel_values input name to input_ids (#14801)
* [ImageGPT] Deprecate pixel_values input name to input_ids * up * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * correct * finish Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c4a96cecbc
commit
84ea427f46
@@ -20,7 +20,7 @@ import unittest
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_common import floats_tensor
|
||||
from .test_modeling_common import floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
BartTokenizer,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
ImageGPTForCausalImageModeling,
|
||||
Speech2TextForConditionalGeneration,
|
||||
SpeechEncoderDecoderModel,
|
||||
VisionEncoderDecoderModel,
|
||||
@@ -1766,6 +1767,18 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
||||
self.assertEqual(output_sequences.shape, (1, 15))
|
||||
|
||||
def test_generate_non_nlp_input_ids_as_kwarg(self):
|
||||
model = ImageGPTForCausalImageModeling.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-imagegpt", max_length=10
|
||||
).to(torch_device)
|
||||
input_ids = ids_tensor((3, 5), vocab_size=10)
|
||||
|
||||
output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
|
||||
output_sequences = model.generate(input_ids).cpu()
|
||||
|
||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
||||
self.assertEqual(output_sequences.shape, (3, 10))
|
||||
|
||||
def test_generate_input_ids_as_encoder_kwarg(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
@@ -314,7 +314,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
expected_arg_names = ["input_ids"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
|
||||
Reference in New Issue
Block a user