Generate: add missing **model_kwargs in sample tests (#18696)
This commit is contained in:
@@ -327,6 +327,7 @@ class GenerationTesterMixin:
|
||||
remove_invalid_values=True,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
@@ -361,6 +362,7 @@ class GenerationTesterMixin:
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_sample, output_generate
|
||||
|
||||
def _beam_search_generate(
|
||||
|
||||
Reference in New Issue
Block a user