rename seq2seq to encoder_decoder

This commit is contained in:
Rémi Louf
2019-10-30 10:54:46 +01:00
parent 9c1bdb5b61
commit 3b0d2fa30e
4 changed files with 14 additions and 16 deletions

View File

@@ -87,7 +87,7 @@ if is_torch_available():
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
# Optimization
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
""" Classes to support Encoder-Decoder architectures """
from __future__ import absolute_import, division, print_function, unicode_literals
@@ -27,9 +27,9 @@ from .modeling_auto import AutoModel, AutoModelWithLMHead
logger = logging.getLogger(__name__)
class PreTrainedSeq2seq(nn.Module):
class PreTrainedEncoderDecoder(nn.Module):
r"""
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
:class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be
instantiated as a transformer architecture with one of the base model
classes of the library as encoder and (optionally) another one as
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
@@ -37,7 +37,7 @@ class PreTrainedSeq2seq(nn.Module):
"""
def __init__(self, encoder, decoder):
super(PreTrainedSeq2seq, self).__init__()
super(PreTrainedEncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
@@ -107,7 +107,7 @@ class PreTrainedSeq2seq(nn.Module):
Examples::
model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
model = PreTrainedEncoderDecoder.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
"""
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
@@ -155,7 +155,7 @@ class PreTrainedSeq2seq(nn.Module):
def save_pretrained(self, save_directory):
""" Save a Seq2Seq model and its configuration file in a format such
that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained`
that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
We save the encoder' and decoder's parameters in two separate directories.
"""
@@ -219,7 +219,7 @@ class PreTrainedSeq2seq(nn.Module):
return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedSeq2seq):
class Model2Model(PreTrainedEncoderDecoder):
r"""
:class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
where both of the encoder and decoder are of the same family. If the
@@ -277,14 +277,14 @@ class Model2Model(PreTrainedSeq2seq):
return model
class Model2LSTM(PreTrainedSeq2seq):
class Model2LSTM(PreTrainedEncoderDecoder):
@classmethod
def from_pretrained(cls, *args, **kwargs):
if kwargs.get("decoder_model", None) is None:
# We will create a randomly initilized LSTM model as decoder
if "decoder_config" not in kwargs:
raise ValueError(
"To load an LSTM in Seq2seq model, please supply either: "
"To load an LSTM in Encoder-Decoder model, please supply either: "
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
" - a dictionary of configuration parameters that will be used to initialize a"
" torch.nn.LSTM model as `decoder_config` keyword argument. "