From e95d433d77727a9babadf008dd621a2326d37303 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 19 Aug 2022 16:14:27 +0100 Subject: [PATCH] Generate: add missing `**model_kwargs` in sample tests (#18696) --- tests/generation/test_generation_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index ba13669368..62a3f588cf 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -327,6 +327,7 @@ class GenerationTesterMixin: remove_invalid_values=True, **logits_warper_kwargs, **process_kwargs, + **model_kwargs, ) torch.manual_seed(0) @@ -361,6 +362,7 @@ class GenerationTesterMixin: **kwargs, **model_kwargs, ) + return output_sample, output_generate def _beam_search_generate(