send model to the correct device (#18800)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-08-29 18:46:30 +02:00
committed by GitHub
parent f1fd460694
commit da5bb29219
2 changed files with 2 additions and 0 deletions

View File

@@ -403,6 +403,7 @@ class EncoderDecoderMixin:
)
model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.to(torch_device)
model.train()
model.gradient_checkpointing_enable()
model.config.decoder_start_token_id = 0