diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 1181b94789..8f565aec06 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -618,8 +618,10 @@ class EncoderDecoderMixin: ) model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - model.train() + model.to(torch_device) model.gradient_checkpointing_enable() + model.train() + model.config.decoder_start_token_id = 0 model.config.pad_token_id = 0 @@ -629,6 +631,8 @@ class EncoderDecoderMixin: "labels": inputs_dict["labels"], "decoder_input_ids": inputs_dict["decoder_input_ids"], } + model_inputs = {k: v.to(torch_device) for k, v in model_inputs.items()} + loss = model(**model_inputs).loss loss.backward()