[generate] do_sample default back to False (#3298)
* change do_samples back * None better default as boolean * adapt do_sample to True in test example * make style
This commit is contained in:
committed by
GitHub
parent
2187c49f5c
commit
e8f44af5bf
@@ -638,16 +638,16 @@ class ModelTesterMixin:
|
||||
|
||||
if config.bos_token_id is None:
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(max_length=5)
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(input_ids))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
|
||||
else:
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(max_length=5))
|
||||
self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
|
||||
self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when greedy no beam generation
|
||||
@@ -659,12 +659,14 @@ class ModelTesterMixin:
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# batch_size > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
|
||||
# batch_size > 1, greedy
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
||||
|
||||
# batch_size > 1, num_beams > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
|
||||
)
|
||||
# batch_size > 1, num_beams > 1, greedy
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
|
||||
|
||||
Reference in New Issue
Block a user