Generate: add missing **model_kwargs in sample tests (#18696)

This commit is contained in:
Joao Gante
2022-08-19 16:14:27 +01:00
committed by GitHub
parent e54a1b49aa
commit e95d433d77

View File

@@ -327,6 +327,7 @@ class GenerationTesterMixin:
remove_invalid_values=True, remove_invalid_values=True,
**logits_warper_kwargs, **logits_warper_kwargs,
**process_kwargs, **process_kwargs,
**model_kwargs,
) )
torch.manual_seed(0) torch.manual_seed(0)
@@ -361,6 +362,7 @@ class GenerationTesterMixin:
**kwargs, **kwargs,
**model_kwargs, **model_kwargs,
) )
return output_sample, output_generate return output_sample, output_generate
def _beam_search_generate( def _beam_search_generate(