up (#13448)
This commit is contained in:
committed by
GitHub
parent
6b29bff852
commit
607611f240
@@ -201,6 +201,7 @@ class EncoderDecoderMixin:
|
|||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
enc_dec_model.save_pretrained(tmpdirname)
|
enc_dec_model.save_pretrained(tmpdirname)
|
||||||
enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
|
enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
|
||||||
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
after_outputs = enc_dec_model(
|
after_outputs = enc_dec_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -245,6 +246,7 @@ class EncoderDecoderMixin:
|
|||||||
encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
|
encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
|
||||||
decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
|
decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
|
||||||
)
|
)
|
||||||
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
after_outputs = enc_dec_model(
|
after_outputs = enc_dec_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user