Add soft length regulation for sequence generation (#15245)
* add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * fix wrong docstring * change param to tuple, add test * fix old param in rag_model, remove unused import * change test according to new param * fix formatting * fix test case * fix doc style * move start_length calculation to Logitprocessor * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * fix test config, fix formatting * change param to tuple, add test * fix old param in rag_model, remove unused import * add possibility to softly regulate length when using sampling method in model.generate() function * change param to tuple, add test * fix old param in rag_model, remove unused import * remove unused import * fix small errors * fix test * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * change param to tuple, add test * fix old param in rag_model, remove unused import * change test according to new param * fix test case * move start_length calculation to Logitprocessor * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * fix test config, fix formatting * change param to tuple, add test * fix old param in rag_model, remove unused import * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * change param to tuple, add test * fix old param in rag_model, remove unused import * fix small errors * Update src/transformers/generation_utils.py 
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/generation_utils.py * Update src/transformers/generation_utils.py * fix docstring, add type ind model rag * fix docstrings * introduce seq_length variable for cleaner code * fix black formatting * add input_ids_seq_length to modeling_rag * add input_ids_seq_length to test * retrigger checks * retrigger checks Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.local> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.fritz.box>
This commit is contained in:
@@ -28,6 +28,7 @@ if is_torch_available():
|
||||
|
||||
from transformers.generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
@@ -504,3 +505,35 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
atol=1e-6,
|
||||
)
|
||||
)
|
||||
|
||||
def test_exponential_decay_length_penalty(self):
    """Check when ExponentialDecayLengthPenalty changes the EOS score.

    The processor is built with a (penalty_start, penalty_factor) tuple and the
    length of the initial input ids. The test asserts two things:

    * with the short (2 token) input the EOS score is returned unchanged;
    * with the long (20 token) input the EOS score is strictly greater than
      the score that was fed in, for every row of the batch.
    """
    vocab_size = 20
    batch_size = 4
    eos_token_id = 0

    penalty_start = 5
    penalty_factor = 1.1

    input_ids = ids_tensor((batch_size, 2), vocab_size=vocab_size)
    input_ids_seq_length = input_ids.shape[-1]

    length_decay_processor = ExponentialDecayLengthPenalty(
        exponential_decay_length_penalty=(penalty_start, penalty_factor),
        eos_token_id=eos_token_id,
        input_ids_seq_length=input_ids_seq_length,
    )

    # check that penalty is not applied before start
    scores = self._get_uniform_logits(batch_size, vocab_size)
    # Compare against a snapshot: if the processor writes into the tensor it is
    # given, comparing the result with `scores` itself would always succeed.
    expected_eos_scores = scores[:, eos_token_id].clone()
    scores_before_start = length_decay_processor(input_ids, scores)
    self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), expected_eos_scores.tolist())

    # check that penalty is applied after start
    input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)
    scores = self._get_uniform_logits(batch_size, vocab_size)
    unpenalized_eos_scores = scores[:, eos_token_id].clone()
    scores_after_start = length_decay_processor(input_ids, scores)
    # The first axis of `scores` is the batch (size 4). Slicing it with
    # `penalty_start + 1 :` selected no rows, and `.all()` of an empty tensor is
    # True, so the check could never fail. Compare every row instead.
    self.assertTrue(torch.gt(scores_after_start[:, eos_token_id], unpenalized_eos_scores).all())
|
||||
|
||||
Reference in New Issue
Block a user