send model to the correct device (#18800)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -403,6 +403,7 @@ class EncoderDecoderMixin:
|
||||
)
|
||||
|
||||
model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.decoder_start_token_id = 0
|
||||
|
||||
Reference in New Issue
Block a user