rename seq2seq to encoder_decoder

This commit is contained in:
Rémi Louf
2019-10-30 10:54:46 +01:00
parent 9c1bdb5b61
commit 3b0d2fa30e
4 changed files with 14 additions and 16 deletions

View File

@@ -32,7 +32,7 @@ from transformers import (
AutoTokenizer,
BertForMaskedLM,
BertConfig,
PreTrainedSeq2seq,
PreTrainedEncoderDecoder,
Model2Model,
)
@@ -475,7 +475,7 @@ def main():
for checkpoint in checkpoints:
encoder_checkpoint = os.path.join(checkpoint, "encoder")
decoder_checkpoint = os.path.join(checkpoint, "decoder")
model = PreTrainedSeq2seq.from_pretrained(
model = PreTrainedEncoderDecoder.from_pretrained(
encoder_checkpoint, decoder_checkpoint
)
model.to(args.device)