Change in-place operations to out-of-place in LogitsProcessors (#29680)
* change in-place -> out-of-place * add tests * add more tests * naming consistency * fix doctest * forgot min-length processors * empty * Revert "fix doctest" This reverts commit 4772768457f9bc057f1d4d9d67ea94eb7224eb8d. * revert change in docstring * Update tests/generation/test_logits_process.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/generation/test_logits_process.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b469ebc5cf
commit
fadb053379
@@ -2417,6 +2417,27 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
|
||||
self.assertTrue(max_score_diff < 1e-5)
|
||||
|
||||
def test_logits_processor_not_inplace(self):
|
||||
# PT-only test: TF fixes were not made
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
out = model.generate(input_ids, output_logits=True, output_scores=True, return_dict_in_generate=True)
|
||||
out_with_temp = model.generate(
|
||||
input_ids,
|
||||
temperature=0.5,
|
||||
do_sample=True,
|
||||
output_logits=True,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# if no logits processor is used, scores == logits. Otherwise, the processor has to modify the scores
|
||||
self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist())
|
||||
self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist())
|
||||
|
||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
||||
# Has TF equivalent: this test relies on random sampling
|
||||
generation_kwargs = {
|
||||
|
||||
Reference in New Issue
Block a user