From e47765d884673e7ee420ed06b4551bfc3d755c8c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 11 Jun 2021 09:04:07 +0100 Subject: [PATCH] Fix head masking generate tests (#12110) * fix_torch_device_generate_test * remove @ * fix tests --- tests/test_generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index ed28c77c07..de986b696d 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -1078,7 +1078,7 @@ class GenerationTesterMixin: attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - model = model_class(config) + model = model_class(config).to(torch_device) # We want to test only encoder-decoder models if not config.is_encoder_decoder: continue