rename seq2seq to encoder_decoder
This commit is contained in:
@@ -32,7 +32,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
BertForMaskedLM,
|
||||
BertConfig,
|
||||
PreTrainedSeq2seq,
|
||||
PreTrainedEncoderDecoder,
|
||||
Model2Model,
|
||||
)
|
||||
|
||||
@@ -475,7 +475,7 @@ def main():
|
||||
for checkpoint in checkpoints:
|
||||
encoder_checkpoint = os.path.join(checkpoint, "encoder")
|
||||
decoder_checkpoint = os.path.join(checkpoint, "decoder")
|
||||
model = PreTrainedSeq2seq.from_pretrained(
|
||||
model = PreTrainedEncoderDecoder.from_pretrained(
|
||||
encoder_checkpoint, decoder_checkpoint
|
||||
)
|
||||
model.to(args.device)
|
||||
|
||||
Reference in New Issue
Block a user