Add soft length regulation for sequence generation (#15245)

* add possibility to softly regulate length when using sampling method in model.generate() function

* fix test config, fix formatting

* fix rag integration, fix docstyling

* fix wrong docstring

* change param to tuple, add test

* fix old param in rag_model, remove unused import

* change test according to new param

* fix formatting

* fix test case

* fix doc style

* move start_length calculation to Logitprocessor

* add possibility to softly regulate length when using sampling method in model.generate() function

* fix rag integration, fix docstyling

* fix test config, fix formatting

* change param to tuple, add test

* fix old param in rag_model, remove unused import

* add possibility to softly regulate length when using sampling method in model.generate() function

* change param to tuple, add test

* fix old param in rag_model, remove unused import

* remove unused import

* fix small errors

* fix test

* add possibility to softly regulate length when using sampling method in model.generate() function

* fix test config, fix formatting

* fix rag integration, fix docstyling

* change param to tuple, add test

* fix old param in rag_model, remove unused import

* change test according to new param

* fix test case

* move start_length calculation to Logitprocessor

* add possibility to softly regulate length when using sampling method in model.generate() function

* fix rag integration, fix docstyling

* fix test config, fix formatting

* change param to tuple, add test

* fix old param in rag_model, remove unused import

* add possibility to softly regulate length when using sampling method in model.generate() function

* fix test config, fix formatting

* fix rag integration, fix docstyling

* add possibility to softly regulate length when using sampling method in model.generate() function

* fix rag integration, fix docstyling

* change param to tuple, add test

* fix old param in rag_model, remove unused import

* fix small errors

* Update src/transformers/generation_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/generation_utils.py

* Update src/transformers/generation_utils.py

* fix docstring, add type ind model rag

* fix docstrings

* introduce seq_length variable for cleaner code

* fix black formatting

* add input_ids_seq_length to modeling_rag

* add input_ids_seq_length to test

* retrigger checks

* retrigger checks

Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.local>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.fritz.box>
This commit is contained in:
Kevin Bondzio
2022-03-11 19:36:44 +01:00
committed by GitHub
parent 322c8533d7
commit 9442b3ce31
6 changed files with 100 additions and 5 deletions

View File

@@ -28,6 +28,7 @@ if is_torch_available():
from transformers.generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
@@ -504,3 +505,35 @@ class LogitsProcessorTest(unittest.TestCase):
atol=1e-6,
)
)
def test_exponential_decay_length_penalty(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
penalty_start = 5
penalty_factor = 1.1
input_ids = ids_tensor((batch_size, 2), vocab_size=vocab_size)
input_ids_seq_length = input_ids.shape[-1]
length_decay_processor = ExponentialDecayLengthPenalty(
exponential_decay_length_penalty=(penalty_start, penalty_factor),
eos_token_id=eos_token_id,
input_ids_seq_length=input_ids_seq_length,
)
# check that penalty is not applied before start
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_start = length_decay_processor(input_ids, scores)
self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), scores[:, eos_token_id].tolist())
# check that penalty is applied after start
input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_after_start = length_decay_processor(input_ids, scores)
self.assertTrue(
torch.gt(
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
).all()
)