💄 super

This commit is contained in:
Julien Chaumond
2020-01-15 18:33:50 -05:00
parent cd51893d37
commit 83a41d39b3
75 changed files with 328 additions and 328 deletions

View File

@@ -37,7 +37,7 @@ class PreTrainedEncoderDecoder(nn.Module):
"""
def __init__(self, encoder, decoder):
super(PreTrainedEncoderDecoder, self).__init__()
super().__init__()
self.encoder = encoder
self.decoder = decoder
@@ -290,7 +290,7 @@ class Model2Model(PreTrainedEncoderDecoder):
"""
def __init__(self, *args, **kwargs):
super(Model2Model, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.tie_weights()
def tie_weights(self):
@@ -321,7 +321,7 @@ class Model2Model(PreTrainedEncoderDecoder):
):
raise ValueError("Only the Bert model is currently supported.")
model = super(Model2Model, cls).from_pretrained(
model = super().from_pretrained(
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
*args,
@@ -345,5 +345,5 @@ class Model2LSTM(PreTrainedEncoderDecoder):
" E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`"
)
kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config"))
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
model = super().from_pretrained(*args, **kwargs)
return model