[Generate] Correct input_ids detection (#14815)
* [Generate] Correct input_ids detection * correct
This commit is contained in:
committed by
GitHub
parent
bdbe3df869
commit
95119ad7b0
@@ -457,7 +457,7 @@ class GenerationMixin:
|
||||
pad_token_id: int,
|
||||
eos_token_id: int,
|
||||
) -> torch.LongTensor:
|
||||
is_input_ids = isinstance(inputs, torch.LongTensor) and len(inputs.shape) == 2
|
||||
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
|
||||
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
|
||||
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
|
||||
(eos_token_id is not None) and (pad_token_id != eos_token_id)
|
||||
|
||||
@@ -1719,6 +1719,31 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# make sure model generated correctly until `max_length`
|
||||
self.assertEqual(output_sequences.shape, (1, 5))
|
||||
|
||||
def test_encoder_decoder_generate_attention_mask(self):
|
||||
articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
# need extrem generation values here to force this test
|
||||
# to fail when `attention_mask` is not correctly treated in generate
|
||||
model = BartForConditionalGeneration.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart", max_length=50, num_beams=5, num_return_sequences=5
|
||||
).to(torch_device)
|
||||
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(articles[0], return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
output_sequences_batched = model.generate(
|
||||
input_ids=input_ids_batched, return_dict_in_generate=True, output_scores=True
|
||||
)
|
||||
output_sequences = model.generate(input_ids=input_ids, return_dict_in_generate=True, output_scores=True)
|
||||
|
||||
batched_out = output_sequences_batched.sequences_scores
|
||||
out = output_sequences.sequences_scores
|
||||
|
||||
diff = (batched_out[:5].sum() - out.sum()).abs()
|
||||
|
||||
self.assertTrue(diff < 1e-4)
|
||||
|
||||
def test_decoder_generate_with_inputs_embeds(self):
|
||||
article = """I need input_ids to generate"""
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
|
||||
Reference in New Issue
Block a user