committed by
GitHub
parent
667b823b89
commit
2410d0f8ed
@@ -1949,11 +1949,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
def test_custom_logits_processor(self):
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
|
||||
# it should not be allowed to both define `min_length` via config and `logits_processor` list
|
||||
with self.assertRaises(ValueError):
|
||||
bart_model.generate(input_ids, logits_processor=logits_processor)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user