rename seq2seq to encoder_decoder
This commit is contained in:
@@ -10,7 +10,7 @@ similar API between the different models.
|
||||
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
|
||||
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
|
||||
| [Multiple Choice](#multiple choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks.
|
||||
| [Seq2seq Model fine-tuning](#seq2seq-model-fine-tuning) | Fine-tuning the library models for seq2seq tasks on the CNN/Daily Mail dataset. |
|
||||
| [Abstractive summarization](#abstractive-summarization) | Fine-tuning the library models for abstractive summarization tasks on the CNN/Daily Mail dataset. |
|
||||
|
||||
## Language model fine-tuning
|
||||
|
||||
@@ -391,7 +391,7 @@ exact_match = 86.91
|
||||
This fine-tuned model is available as a checkpoint under the reference
|
||||
`bert-large-uncased-whole-word-masking-finetuned-squad`.
|
||||
|
||||
## Seq2seq model fine-tuning
|
||||
## Abstractive summarization
|
||||
|
||||
Based on the script
|
||||
[`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
|
||||
@@ -408,8 +408,6 @@ note that the finetuning script **will not work** if you do not download both
|
||||
datasets. We will refer as `$DATA_PATH` the path to where you uncompressed both
|
||||
archive.
|
||||
|
||||
## Bert2Bert and abstractive summarization
|
||||
|
||||
```bash
|
||||
export DATA_PATH=/path/to/dataset/
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
BertForMaskedLM,
|
||||
BertConfig,
|
||||
PreTrainedSeq2seq,
|
||||
PreTrainedEncoderDecoder,
|
||||
Model2Model,
|
||||
)
|
||||
|
||||
@@ -475,7 +475,7 @@ def main():
|
||||
for checkpoint in checkpoints:
|
||||
encoder_checkpoint = os.path.join(checkpoint, "encoder")
|
||||
decoder_checkpoint = os.path.join(checkpoint, "decoder")
|
||||
model = PreTrainedSeq2seq.from_pretrained(
|
||||
model = PreTrainedEncoderDecoder.from_pretrained(
|
||||
encoder_checkpoint, decoder_checkpoint
|
||||
)
|
||||
model.to(args.device)
|
||||
|
||||
@@ -87,7 +87,7 @@ if is_torch_available():
|
||||
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
||||
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
|
||||
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
|
||||
|
||||
# Optimization
|
||||
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Auto Model class. """
|
||||
""" Classes to support Encoder-Decoder architectures """
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
@@ -27,9 +27,9 @@ from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreTrainedSeq2seq(nn.Module):
|
||||
class PreTrainedEncoderDecoder(nn.Module):
|
||||
r"""
|
||||
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
|
||||
:class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be
|
||||
instantiated as a transformer architecture with one of the base model
|
||||
classes of the library as encoder and (optionally) another one as
|
||||
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||
@@ -37,7 +37,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, encoder, decoder):
|
||||
super(PreTrainedSeq2seq, self).__init__()
|
||||
super(PreTrainedEncoderDecoder, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
@@ -107,7 +107,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
|
||||
Examples::
|
||||
|
||||
model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||
model = PreTrainedEncoderDecoder.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||
"""
|
||||
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
@@ -155,7 +155,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a Seq2Seq model and its configuration file in a format such
|
||||
that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained`
|
||||
that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
|
||||
|
||||
We save the encoder' and decoder's parameters in two separate directories.
|
||||
"""
|
||||
@@ -219,7 +219,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedSeq2seq):
|
||||
class Model2Model(PreTrainedEncoderDecoder):
|
||||
r"""
|
||||
:class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
|
||||
where both of the encoder and decoder are of the same family. If the
|
||||
@@ -277,14 +277,14 @@ class Model2Model(PreTrainedSeq2seq):
|
||||
return model
|
||||
|
||||
|
||||
class Model2LSTM(PreTrainedSeq2seq):
|
||||
class Model2LSTM(PreTrainedEncoderDecoder):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
if kwargs.get("decoder_model", None) is None:
|
||||
# We will create a randomly initilized LSTM model as decoder
|
||||
if "decoder_config" not in kwargs:
|
||||
raise ValueError(
|
||||
"To load an LSTM in Seq2seq model, please supply either: "
|
||||
"To load an LSTM in Encoder-Decoder model, please supply either: "
|
||||
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
|
||||
" - a dictionary of configuration parameters that will be used to initialize a"
|
||||
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
||||
Reference in New Issue
Block a user