Fix gradient checkpoint test in encoder-decoder (#20017)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -618,8 +618,10 @@ class EncoderDecoderMixin:
|
||||
)
|
||||
|
||||
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
model.train()
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable()
|
||||
model.train()
|
||||
|
||||
model.config.decoder_start_token_id = 0
|
||||
model.config.pad_token_id = 0
|
||||
|
||||
@@ -629,6 +631,8 @@ class EncoderDecoderMixin:
|
||||
"labels": inputs_dict["labels"],
|
||||
"decoder_input_ids": inputs_dict["decoder_input_ids"],
|
||||
}
|
||||
model_inputs = {k: v.to(torch_device) for k, v in model_inputs.items()}
|
||||
|
||||
loss = model(**model_inputs).loss
|
||||
loss.backward()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user