rename seq2seq to encoder_decoder
This commit is contained in:
@@ -87,7 +87,7 @@ if is_torch_available():
|
||||
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
||||
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
|
||||
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
|
||||
|
||||
# Optimization
|
||||
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Auto Model class. """
|
||||
""" Classes to support Encoder-Decoder architectures """
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
@@ -27,9 +27,9 @@ from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreTrainedSeq2seq(nn.Module):
|
||||
class PreTrainedEncoderDecoder(nn.Module):
|
||||
r"""
|
||||
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
|
||||
:class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be
|
||||
instantiated as a transformer architecture with one of the base model
|
||||
classes of the library as encoder and (optionally) another one as
|
||||
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||
@@ -37,7 +37,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, encoder, decoder):
|
||||
super(PreTrainedSeq2seq, self).__init__()
|
||||
super(PreTrainedEncoderDecoder, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
@@ -107,7 +107,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
|
||||
Examples::
|
||||
|
||||
model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||
"""
|
||||
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
@@ -155,7 +155,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a Seq2Seq model and its configuration file in a format such
|
||||
that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained`
|
||||
that it can be loaded using :func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
|
||||
|
||||
We save the encoder's and decoder's parameters in two separate directories.
|
||||
"""
|
||||
@@ -219,7 +219,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedSeq2seq):
|
||||
class Model2Model(PreTrainedEncoderDecoder):
|
||||
r"""
|
||||
:class:`~transformers.Model2Model` instantiates a Seq2Seq model
|
||||
where both of the encoder and decoder are of the same family. If the
|
||||
@@ -277,14 +277,14 @@ class Model2Model(PreTrainedSeq2seq):
|
||||
return model
|
||||
|
||||
|
||||
class Model2LSTM(PreTrainedSeq2seq):
|
||||
class Model2LSTM(PreTrainedEncoderDecoder):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
if kwargs.get("decoder_model", None) is None:
|
||||
# We will create a randomly initialized LSTM model as decoder
|
||||
if "decoder_config" not in kwargs:
|
||||
raise ValueError(
|
||||
"To load an LSTM in Seq2seq model, please supply either: "
|
||||
"To load an LSTM in Encoder-Decoder model, please supply either: "
|
||||
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
|
||||
" - a dictionary of configuration parameters that will be used to initialize a"
|
||||
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
||||
Reference in New Issue
Block a user