send model to the correct device (#18800)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -403,6 +403,7 @@ class EncoderDecoderMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
model.config.decoder_start_token_id = 0
|
model.config.decoder_start_token_id = 0
|
||||||
|
|||||||
@@ -331,6 +331,7 @@ class EncoderDecoderMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
model.config.decoder_start_token_id = 0
|
model.config.decoder_start_token_id = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user