💄 super
This commit is contained in:
@@ -37,7 +37,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, encoder, decoder):
|
||||
super(PreTrainedEncoderDecoder, self).__init__()
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
@@ -290,7 +290,7 @@ class Model2Model(PreTrainedEncoderDecoder):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Model2Model, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
@@ -321,7 +321,7 @@ class Model2Model(PreTrainedEncoderDecoder):
|
||||
):
|
||||
raise ValueError("Only the Bert model is currently supported.")
|
||||
|
||||
model = super(Model2Model, cls).from_pretrained(
|
||||
model = super().from_pretrained(
|
||||
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
*args,
|
||||
@@ -345,5 +345,5 @@ class Model2LSTM(PreTrainedEncoderDecoder):
|
||||
" E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`"
|
||||
)
|
||||
kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config"))
|
||||
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
|
||||
model = super().from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user