This commit is contained in:
Patrick von Platen
2021-09-06 16:09:24 +02:00
committed by GitHub
parent 6b29bff852
commit 607611f240

View File

@@ -201,6 +201,7 @@ class EncoderDecoderMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname)
enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
enc_dec_model.to(torch_device)
after_outputs = enc_dec_model(
input_ids=input_ids,
@@ -245,6 +246,7 @@ class EncoderDecoderMixin:
encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
)
enc_dec_model.to(torch_device)
after_outputs = enc_dec_model(
input_ids=input_ids,