[generate] do_sample default back to False (#3298)

* change do_samples back

* None better default as boolean

* adapt do_sample to True in test example

* make style
This commit is contained in:
Patrick von Platen
2020-03-17 15:52:37 +01:00
committed by GitHub
parent 2187c49f5c
commit e8f44af5bf
6 changed files with 30 additions and 22 deletions

View File

@@ -638,16 +638,16 @@ class ModelTesterMixin:
if config.bos_token_id is None:
with self.assertRaises(AssertionError):
model.generate(max_length=5)
model.generate(do_sample=True, max_length=5)
# batch_size = 1
self._check_generated_tokens(model.generate(input_ids))
self._check_generated_tokens(model.generate(input_ids, do_sample=True))
# batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(input_ids, num_beams=3))
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
else:
# batch_size = 1
self._check_generated_tokens(model.generate(max_length=5))
self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
# batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
with self.assertRaises(AssertionError):
# generating multiple sequences when greedy no beam generation
@@ -659,12 +659,14 @@ class ModelTesterMixin:
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# batch_size > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
# batch_size > 1, greedy
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
self._check_generated_tokens(
model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
)
# batch_size > 1, num_beams > 1, greedy
self._check_generated_tokens(
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)