rename seq2seq to encoder_decoder

This commit is contained in:
Rémi Louf
2019-10-30 10:54:46 +01:00
parent 9c1bdb5b61
commit 3b0d2fa30e
4 changed files with 14 additions and 16 deletions

View File

@@ -10,7 +10,7 @@ similar API between the different models.
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
| [Multiple Choice](#multiple choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks.
| [Seq2seq Model fine-tuning](#seq2seq-model-fine-tuning) | Fine-tuning the library models for seq2seq tasks on the CNN/Daily Mail dataset. |
| [Abstractive summarization](#abstractive-summarization) | Fine-tuning the library models for abstractive summarization tasks on the CNN/Daily Mail dataset. |
## Language model fine-tuning
@@ -391,7 +391,7 @@ exact_match = 86.91
This fine-tuned model is available as a checkpoint under the reference
`bert-large-uncased-whole-word-masking-finetuned-squad`.
## Seq2seq model fine-tuning
## Abstractive summarization
Based on the script
[`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
@@ -408,8 +408,6 @@ note that the finetuning script **will not work** if you do not download both
datasets. We will refer as `$DATA_PATH` the path to where you uncompressed both
archive.
## Bert2Bert and abstractive summarization
```bash
export DATA_PATH=/path/to/dataset/

View File

@@ -32,7 +32,7 @@ from transformers import (
AutoTokenizer,
BertForMaskedLM,
BertConfig,
PreTrainedSeq2seq,
PreTrainedEncoderDecoder,
Model2Model,
)
@@ -475,7 +475,7 @@ def main():
for checkpoint in checkpoints:
encoder_checkpoint = os.path.join(checkpoint, "encoder")
decoder_checkpoint = os.path.join(checkpoint, "decoder")
model = PreTrainedSeq2seq.from_pretrained(
model = PreTrainedEncoderDecoder.from_pretrained(
encoder_checkpoint, decoder_checkpoint
)
model.to(args.device)