Change in-place operations to out-of-place in LogitsProcessors (#29680)

* change in-place -> out-of-place

* add tests

* add more tests

* naming consistency

* fix doctest

* forgot min-length processors

* empty

* Revert "fix doctest"

This reverts commit 4772768457f9bc057f1d4d9d67ea94eb7224eb8d.

* revert change in docstring

* Update tests/generation/test_logits_process.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/generation/test_logits_process.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-03-21 21:37:33 +05:00
committed by GitHub
parent b469ebc5cf
commit fadb053379
3 changed files with 226 additions and 118 deletions

View File

@@ -2417,6 +2417,27 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
self.assertTrue(max_score_diff < 1e-5)
def test_logits_processor_not_inplace(self):
# PT-only test: TF fixes were not made
article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
out = model.generate(input_ids, output_logits=True, output_scores=True, return_dict_in_generate=True)
out_with_temp = model.generate(
input_ids,
temperature=0.5,
do_sample=True,
output_logits=True,
output_scores=True,
return_dict_in_generate=True,
)
# if no logits processor is used, scores == logits. Otherwise, the processor has to modify the scores
self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist())
self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist())
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has TF equivalent: this test relies on random sampling
generation_kwargs = {