From 3b0d2fa30eb9756c888b4ed36213350d4b6e70e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 30 Oct 2019 10:54:46 +0100 Subject: [PATCH] rename seq2seq to encoder_decoder --- examples/README.md | 6 ++---- examples/run_summarization_finetuning.py | 4 ++-- transformers/__init__.py | 2 +- ..._seq2seq.py => modeling_encoder_decoder.py} | 18 +++++++++--------- 4 files changed, 14 insertions(+), 16 deletions(-) rename transformers/{modeling_seq2seq.py => modeling_encoder_decoder.py} (96%) diff --git a/examples/README.md b/examples/README.md index bec6d57171..6d27a0c560 100644 --- a/examples/README.md +++ b/examples/README.md @@ -10,7 +10,7 @@ similar API between the different models. | [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. | | [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. | | [Multiple Choice](#multiple choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. -| [Seq2seq Model fine-tuning](#seq2seq-model-fine-tuning) | Fine-tuning the library models for seq2seq tasks on the CNN/Daily Mail dataset. | +| [Abstractive summarization](#abstractive-summarization) | Fine-tuning the library models for abstractive summarization tasks on the CNN/Daily Mail dataset. | ## Language model fine-tuning @@ -391,7 +391,7 @@ exact_match = 86.91 This fine-tuned model is available as a checkpoint under the reference `bert-large-uncased-whole-word-masking-finetuned-squad`. -## Seq2seq model fine-tuning +## Abstractive summarization Based on the script [`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py). @@ -408,8 +408,6 @@ note that the finetuning script **will not work** if you do not download both datasets. We will refer as `$DATA_PATH` the path to where you uncompressed both archive. -## Bert2Bert and abstractive summarization - ```bash export DATA_PATH=/path/to/dataset/ diff --git a/examples/run_summarization_finetuning.py b/examples/run_summarization_finetuning.py index 3d194950c7..448505c727 100644 --- a/examples/run_summarization_finetuning.py +++ b/examples/run_summarization_finetuning.py @@ -32,7 +32,7 @@ from transformers import ( AutoTokenizer, BertForMaskedLM, BertConfig, - PreTrainedSeq2seq, + PreTrainedEncoderDecoder, Model2Model, ) @@ -475,7 +475,7 @@ def main(): for checkpoint in checkpoints: encoder_checkpoint = os.path.join(checkpoint, "encoder") decoder_checkpoint = os.path.join(checkpoint, "decoder") - model = PreTrainedSeq2seq.from_pretrained( + model = PreTrainedEncoderDecoder.from_pretrained( encoder_checkpoint, decoder_checkpoint ) model.to(args.device) diff --git a/transformers/__init__.py b/transformers/__init__.py index 2206a0302e..844aa22295 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -87,7 +87,7 @@ if is_torch_available(): from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) - from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model + from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model # Optimization from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_encoder_decoder.py similarity index 96% rename from transformers/modeling_seq2seq.py rename to transformers/modeling_encoder_decoder.py index ba8c546a30..162e2f8b3b 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_encoder_decoder.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Auto Model class. """ +""" Classes to support Encoder-Decoder architectures """ from __future__ import absolute_import, division, print_function, unicode_literals @@ -27,9 +27,9 @@ from .modeling_auto import AutoModel, AutoModelWithLMHead logger = logging.getLogger(__name__) -class PreTrainedSeq2seq(nn.Module): +class PreTrainedEncoderDecoder(nn.Module): r""" - :class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be + :class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be instantiated as a transformer architecture with one of the base model classes of the library as encoder and (optionally) another one as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` @@ -37,7 +37,7 @@ class PreTrainedSeq2seq(nn.Module): """ def __init__(self, encoder, decoder): - super(PreTrainedSeq2seq, self).__init__() + super(PreTrainedEncoderDecoder, self).__init__() self.encoder = encoder self.decoder = decoder @@ -107,7 +107,7 @@ class PreTrainedSeq2seq(nn.Module): Examples:: - model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert + model = PreTrainedEncoderDecoder.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert """ # keyword arguments come in 3 flavors: encoder-specific (prefixed by @@ -155,7 +155,7 @@ class PreTrainedSeq2seq(nn.Module): def save_pretrained(self, save_directory): """ Save a Seq2Seq model and its configuration file in a format such - that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` + that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained` We save the encoder' and decoder's parameters in two separate directories. """ @@ -219,7 +219,7 @@ class PreTrainedSeq2seq(nn.Module): return decoder_outputs + encoder_outputs -class Model2Model(PreTrainedSeq2seq): +class Model2Model(PreTrainedEncoderDecoder): r""" :class:`~transformers.Model2Model` instantiates a Seq2Seq2 model where both of the encoder and decoder are of the same family. If the @@ -277,14 +277,14 @@ class Model2Model(PreTrainedSeq2seq): return model -class Model2LSTM(PreTrainedSeq2seq): +class Model2LSTM(PreTrainedEncoderDecoder): @classmethod def from_pretrained(cls, *args, **kwargs): if kwargs.get("decoder_model", None) is None: # We will create a randomly initilized LSTM model as decoder if "decoder_config" not in kwargs: raise ValueError( - "To load an LSTM in Seq2seq model, please supply either: " + "To load an LSTM in Encoder-Decoder model, please supply either: " " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or" " - a dictionary of configuration parameters that will be used to initialize a" " torch.nn.LSTM model as `decoder_config` keyword argument. "