Delete untested, broken Model2LSTM (#2968)
This commit is contained in:
@@ -18,7 +18,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||||
@@ -294,21 +293,3 @@ class Model2Model(PreTrainedEncoderDecoder):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class Model2LSTM(PreTrainedEncoderDecoder):
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
if kwargs.get("decoder_model", None) is None:
|
|
||||||
# We will create a randomly initilized LSTM model as decoder
|
|
||||||
if "decoder_config" not in kwargs:
|
|
||||||
raise ValueError(
|
|
||||||
"To load an LSTM in Encoder-Decoder model, please supply either: "
|
|
||||||
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
|
|
||||||
" - a dictionary of configuration parameters that will be used to initialize a"
|
|
||||||
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
|
||||||
" E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`"
|
|
||||||
)
|
|
||||||
kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config"))
|
|
||||||
model = super().from_pretrained(*args, **kwargs)
|
|
||||||
return model
|
|
||||||
|
|||||||
Reference in New Issue
Block a user