Generate: add missing **model_kwargs in sample tests (#18696)

This commit is contained in:
Joao Gante
2022-08-19 16:14:27 +01:00
committed by GitHub
parent e54a1b49aa
commit e95d433d77

View File

@@ -327,6 +327,7 @@ class GenerationTesterMixin:
remove_invalid_values=True,
**logits_warper_kwargs,
**process_kwargs,
**model_kwargs,
)
torch.manual_seed(0)
@@ -361,6 +362,7 @@ class GenerationTesterMixin:
**kwargs,
**model_kwargs,
)
return output_sample, output_generate
def _beam_search_generate(