Fix gradient checkpoint test in encoder-decoder (#20017)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -618,8 +618,10 @@ class EncoderDecoderMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
model.train()
|
model.to(torch_device)
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
model.train()
|
||||||
|
|
||||||
model.config.decoder_start_token_id = 0
|
model.config.decoder_start_token_id = 0
|
||||||
model.config.pad_token_id = 0
|
model.config.pad_token_id = 0
|
||||||
|
|
||||||
@@ -629,6 +631,8 @@ class EncoderDecoderMixin:
|
|||||||
"labels": inputs_dict["labels"],
|
"labels": inputs_dict["labels"],
|
||||||
"decoder_input_ids": inputs_dict["decoder_input_ids"],
|
"decoder_input_ids": inputs_dict["decoder_input_ids"],
|
||||||
}
|
}
|
||||||
|
model_inputs = {k: v.to(torch_device) for k, v in model_inputs.items()}
|
||||||
|
|
||||||
loss = model(**model_inputs).loss
|
loss = model(**model_inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user