Generate: make TF .generate() signature == PT .generate() signature (#21525)
This commit is contained in:
@@ -665,7 +665,7 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[tf.Tensor] = None,
|
inputs: Optional[tf.Tensor] = None,
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
logits_processor: Optional[TFLogitsProcessorList] = None,
|
logits_processor: Optional[TFLogitsProcessorList] = None,
|
||||||
seed=None,
|
seed=None,
|
||||||
@@ -686,9 +686,11 @@ class TFGenerationMixin:
|
|||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
|
inputs (`tf.Tensor` of varying shape depending on the modality, *optional*):
|
||||||
The sequence used as a prompt for the generation. If `None` the method initializes it with
|
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
|
||||||
`bos_token_id` and a batch size of 1.
|
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
|
||||||
|
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
|
||||||
|
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
|
||||||
generation_config (`~generation.GenerationConfig`, *optional*):
|
generation_config (`~generation.GenerationConfig`, *optional*):
|
||||||
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
||||||
passed to generate matching the attributes of `generation_config` will override them. If
|
passed to generate matching the attributes of `generation_config` will override them. If
|
||||||
@@ -755,13 +757,13 @@ class TFGenerationMixin:
|
|||||||
self._validate_model_kwargs(model_kwargs.copy())
|
self._validate_model_kwargs(model_kwargs.copy())
|
||||||
|
|
||||||
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
|
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
|
||||||
if input_ids is not None:
|
if inputs is not None:
|
||||||
if isinstance(input_ids, tf.Tensor) and input_ids.dtype.is_floating:
|
if isinstance(inputs, tf.Tensor) and inputs.dtype.is_floating:
|
||||||
pass
|
pass
|
||||||
elif isinstance(input_ids, np.ndarray) and np.issubdtype(input_ids.dtype, np.floating):
|
elif isinstance(inputs, np.ndarray) and np.issubdtype(inputs.dtype, np.floating):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
input_ids = tf.cast(input_ids, tf.int32)
|
inputs = tf.cast(inputs, tf.int32)
|
||||||
if model_kwargs.get("attention_mask") is not None:
|
if model_kwargs.get("attention_mask") is not None:
|
||||||
model_kwargs["attention_mask"] = tf.cast(model_kwargs["attention_mask"], tf.int32)
|
model_kwargs["attention_mask"] = tf.cast(model_kwargs["attention_mask"], tf.int32)
|
||||||
if "decoder_input_ids" in model_kwargs:
|
if "decoder_input_ids" in model_kwargs:
|
||||||
@@ -800,7 +802,7 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# 4. Define model inputs
|
# 4. Define model inputs
|
||||||
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
||||||
input_ids, generation_config.bos_token_id, model_kwargs
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
)
|
)
|
||||||
# inputs_ids now has to be defined and cannot be None anymore
|
# inputs_ids now has to be defined and cannot be None anymore
|
||||||
batch_size = shape_list(inputs_tensor)[0]
|
batch_size = shape_list(inputs_tensor)[0]
|
||||||
|
|||||||
@@ -396,3 +396,93 @@ class GenerationIntegrationTestsMixin:
|
|||||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||||
|
|
||||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores))
|
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores))
|
||||||
|
|
||||||
|
def test_encoder_decoder_generate_attention_mask(self):
|
||||||
|
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||||
|
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||||
|
is_pt = not model_cls.__name__.startswith("TF")
|
||||||
|
|
||||||
|
articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
# need extreme generation values here to force this test
|
||||||
|
# to fail when `attention_mask` is not correctly treated in generate
|
||||||
|
model = model_cls.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bart", max_length=50, num_beams=5, num_return_sequences=5
|
||||||
|
)
|
||||||
|
model.config.eos_token_id = None
|
||||||
|
input_ids = tokenizer(articles[0], return_tensors=return_tensors).input_ids
|
||||||
|
input_ids_batched = tokenizer(articles, padding=True, return_tensors=return_tensors).input_ids
|
||||||
|
if is_pt:
|
||||||
|
model = model.to(torch_device)
|
||||||
|
input_ids = input_ids.to(torch_device)
|
||||||
|
input_ids_batched = input_ids_batched.to(torch_device)
|
||||||
|
|
||||||
|
output_sequences_batched = model.generate(
|
||||||
|
input_ids=input_ids_batched, return_dict_in_generate=True, output_scores=True
|
||||||
|
)
|
||||||
|
output_sequences = model.generate(input_ids=input_ids, return_dict_in_generate=True, output_scores=True)
|
||||||
|
|
||||||
|
batched_out = output_sequences_batched.sequences_scores
|
||||||
|
out = output_sequences.sequences_scores
|
||||||
|
if is_pt:
|
||||||
|
batched_out = batched_out.cpu().numpy()
|
||||||
|
out = out.cpu().numpy()
|
||||||
|
|
||||||
|
diff = np.abs(np.sum(batched_out[:5]) - np.sum(out))
|
||||||
|
self.assertTrue(diff < 1e-4)
|
||||||
|
|
||||||
|
def test_generate_input_ids_as_kwarg(self):
|
||||||
|
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||||
|
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||||
|
is_pt = not model_cls.__name__.startswith("TF")
|
||||||
|
|
||||||
|
article = """I need input_ids to generate"""
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15)
|
||||||
|
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||||
|
if is_pt:
|
||||||
|
model = model.to(torch_device)
|
||||||
|
input_ids = input_ids.to(torch_device)
|
||||||
|
|
||||||
|
output_sequences_kwargs = model.generate(input_ids=input_ids)
|
||||||
|
output_sequences = model.generate(input_ids)
|
||||||
|
if is_pt:
|
||||||
|
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||||
|
output_sequences = output_sequences.cpu().numpy()
|
||||||
|
|
||||||
|
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||||
|
self.assertEqual(output_sequences.shape, (1, 15))
|
||||||
|
|
||||||
|
def test_generate_input_ids_as_encoder_kwarg(self):
|
||||||
|
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||||
|
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||||
|
is_pt = not model_cls.__name__.startswith("TF")
|
||||||
|
|
||||||
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
|
||||||
|
model.config.eos_token_id = None
|
||||||
|
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||||
|
if is_pt:
|
||||||
|
model = model.to(torch_device)
|
||||||
|
input_ids = input_ids.to(torch_device)
|
||||||
|
|
||||||
|
output_sequences_kwargs = model.generate(input_ids=input_ids)
|
||||||
|
output_sequences = model.generate(input_ids)
|
||||||
|
if is_pt:
|
||||||
|
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||||
|
output_sequences = output_sequences.cpu().numpy()
|
||||||
|
|
||||||
|
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||||
|
self.assertEqual(output_sequences.shape, (1, 5))
|
||||||
|
|
||||||
|
def test_generate_inputs_and_encoder_kwargs(self):
|
||||||
|
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||||
|
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||||
|
|
||||||
|
article = """I need input_ids to generate"""
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10)
|
||||||
|
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.generate(input_ids, input_ids=input_ids)
|
||||||
|
|||||||
@@ -2092,43 +2092,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
output = generator(prompt, stop_sequence=" number")
|
output = generator(prompt, stop_sequence=" number")
|
||||||
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
|
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
|
||||||
|
|
||||||
def test_encoder_decoder_generate_attention_mask(self):
|
|
||||||
articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
|
|
||||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
|
||||||
# need extrem generation values here to force this test
|
|
||||||
# to fail when `attention_mask` is not correctly treated in generate
|
|
||||||
model = BartForConditionalGeneration.from_pretrained(
|
|
||||||
"hf-internal-testing/tiny-random-bart", max_length=50, num_beams=5, num_return_sequences=5
|
|
||||||
).to(torch_device)
|
|
||||||
|
|
||||||
model.config.eos_token_id = None
|
|
||||||
input_ids = tokenizer(articles[0], return_tensors="pt").input_ids.to(torch_device)
|
|
||||||
input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device)
|
|
||||||
|
|
||||||
output_sequences_batched = model.generate(
|
|
||||||
input_ids=input_ids_batched, return_dict_in_generate=True, output_scores=True
|
|
||||||
)
|
|
||||||
output_sequences = model.generate(input_ids=input_ids, return_dict_in_generate=True, output_scores=True)
|
|
||||||
|
|
||||||
batched_out = output_sequences_batched.sequences_scores
|
|
||||||
out = output_sequences.sequences_scores
|
|
||||||
|
|
||||||
diff = (batched_out[:5].sum() - out.sum()).abs()
|
|
||||||
|
|
||||||
self.assertTrue(diff < 1e-4)
|
|
||||||
|
|
||||||
def test_generate_input_ids_as_kwarg(self):
|
|
||||||
article = """I need input_ids to generate"""
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
|
||||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15).to(torch_device)
|
|
||||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
|
||||||
output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
|
|
||||||
output_sequences = model.generate(input_ids).cpu()
|
|
||||||
|
|
||||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
|
||||||
self.assertEqual(output_sequences.shape, (1, 15))
|
|
||||||
|
|
||||||
def test_generate_non_nlp_input_ids_as_kwarg(self):
|
def test_generate_non_nlp_input_ids_as_kwarg(self):
|
||||||
|
# PT-only test: AFAIK there is no non-NLP model architecture in TF that supports `input_ids` as its only input
|
||||||
model = ImageGPTForCausalImageModeling.from_pretrained(
|
model = ImageGPTForCausalImageModeling.from_pretrained(
|
||||||
"hf-internal-testing/tiny-random-imagegpt", max_length=10
|
"hf-internal-testing/tiny-random-imagegpt", max_length=10
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
@@ -2140,28 +2105,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
||||||
self.assertEqual(output_sequences.shape, (3, 10))
|
self.assertEqual(output_sequences.shape, (3, 10))
|
||||||
|
|
||||||
def test_generate_input_ids_as_encoder_kwarg(self):
|
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
|
||||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
|
||||||
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
|
|
||||||
torch_device
|
|
||||||
)
|
|
||||||
model.config.eos_token_id = None
|
|
||||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
|
||||||
output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
|
|
||||||
output_sequences = model.generate(input_ids).cpu()
|
|
||||||
|
|
||||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
|
||||||
self.assertEqual(output_sequences.shape, (1, 5))
|
|
||||||
|
|
||||||
def test_generate_inputs_and_encoder_kwargs(self):
|
|
||||||
article = """I need input_ids to generate"""
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
|
||||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
|
|
||||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
model.generate(input_ids, input_ids=input_ids)
|
|
||||||
|
|
||||||
def test_generate_too_many_encoder_kwargs(self):
|
def test_generate_too_many_encoder_kwargs(self):
|
||||||
article = """I need input_ids to generate"""
|
article = """I need input_ids to generate"""
|
||||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
|||||||
Reference in New Issue
Block a user