From da5bb2921907c398e61ea1b73fd22d13938fc427 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 29 Aug 2022 18:46:30 +0200 Subject: [PATCH] send model to the correct device (#18800) Co-authored-by: ydshieh --- .../test_modeling_speech_encoder_decoder.py | 1 + .../test_modeling_vision_encoder_decoder.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index f241501302..3ecca17324 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -403,6 +403,7 @@ class EncoderDecoderMixin: ) model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + model.to(torch_device) model.train() model.gradient_checkpointing_enable() model.config.decoder_start_token_id = 0 diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index fbac8b898a..279614371b 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -331,6 +331,7 @@ class EncoderDecoderMixin: ) model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + model.to(torch_device) model.train() model.gradient_checkpointing_enable() model.config.decoder_start_token_id = 0