[ImageGPT] Deprecate pixel_values input name to input_ids (#14801)

* [ImageGPT] Deprecate pixel_values input name to input_ids

* up

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* correct

* finish

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-12-17 20:05:22 +01:00
committed by GitHub
parent c4a96cecbc
commit 84ea427f46
3 changed files with 93 additions and 33 deletions

View File

@@ -20,7 +20,7 @@ import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_common import floats_tensor
from .test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
@@ -31,6 +31,7 @@ if is_torch_available():
BartTokenizer,
GPT2LMHeadModel,
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
Speech2TextForConditionalGeneration,
SpeechEncoderDecoderModel,
VisionEncoderDecoderModel,
@@ -1766,6 +1767,18 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (1, 15))
def test_generate_non_nlp_input_ids_as_kwarg(self):
model = ImageGPTForCausalImageModeling.from_pretrained(
"hf-internal-testing/tiny-random-imagegpt", max_length=10
).to(torch_device)
input_ids = ids_tensor((3, 5), vocab_size=10)
output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
output_sequences = model.generate(input_ids).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (3, 10))
def test_generate_input_ids_as_encoder_kwarg(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")

View File

@@ -314,7 +314,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_resize_tokens_embeddings(self):