[testing] fix ambiguous test (#6898)
Since `generate()` does:
```
num_beams = num_beams if num_beams is not None else self.config.num_beams
```
This test fails if `model.config.num_beams > 1` (which is the case in the model I'm porting).
This fix makes the test setup unambiguous by passing an explicit `num_beams=1` to `generate()`.
Thanks.
This commit is contained in:
@@ -822,7 +822,7 @@ class ModelTesterMixin:
|
|||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
# generating multiple sequences when no beam search generation
|
# generating multiple sequences when no beam search generation
|
||||||
# is not allowed as it would always generate the same sequences
|
# is not allowed as it would always generate the same sequences
|
||||||
model.generate(input_ids, do_sample=False, num_return_sequences=2)
|
model.generate(input_ids, do_sample=False, num_beams=1, num_return_sequences=2)
|
||||||
|
|
||||||
# num_return_sequences > 1, sample
|
# num_return_sequences > 1, sample
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
|
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
|
||||||
|
|||||||
Reference in New Issue
Block a user