diff --git a/transformers/modeling_encoder_decoder.py b/transformers/modeling_encoder_decoder.py index 162e2f8b3b..a884abd0a2 100644 --- a/transformers/modeling_encoder_decoder.py +++ b/transformers/modeling_encoder_decoder.py @@ -114,23 +114,28 @@ class PreTrainedEncoderDecoder(nn.Module): # `encoder_`), decoder-specific (prefixed by `decoder_`) and those # that apply to the model as a whole. # We let the specific kwargs override the common ones in case of conflict. - kwargs_encoder = { - argument[len("encoder_"):]: value - for argument, value in kwargs.items() - if argument.startswith("encoder_") - } - kwargs_decoder = { - argument[len("decoder_"):]: value - for argument, value in kwargs.items() - if argument.startswith("decoder_") - } kwargs_common = { argument: value for argument, value in kwargs.items() - if not (argument.startswith("encoder_") or argument.startswith("decoder_")) + if not argument.startswith("encoder_") + and not argument.startswith("decoder_") } - kwargs_decoder = dict(kwargs_common, **kwargs_decoder) - kwargs_encoder = dict(kwargs_common, **kwargs_encoder) + kwargs_decoder = kwargs_common.copy() + kwargs_encoder = kwargs_common.copy() + kwargs_encoder.update( + { + argument[len("encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("encoder_") + } + ) + kwargs_decoder.update( + { + argument[len("decoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("decoder_") + } + ) # Load and initialize the encoder and decoder # The distinction between encoder and decoder at the model level is made @@ -185,35 +190,44 @@ class PreTrainedEncoderDecoder(nn.Module): # `encoder_`), decoder-specific (prefixed by `decoder_`) and those # that apply to the model as whole. # We let the specific kwargs override the common ones in case of conflict. - kwargs_encoder = { - argument[len("encoder_"):]: value - for argument, value in kwargs.items() - if argument.startswith("encoder_") - } - kwargs_decoder = { - argument[len("decoder_"):]: value - for argument, value in kwargs.items() - if argument.startswith("decoder_") - } kwargs_common = { argument: value for argument, value in kwargs.items() - if not (argument.startswith("encoder_") or argument.startswith("decoder_")) + if not argument.startswith("encoder_") + and not argument.startswith("decoder_") } - kwargs_decoder = dict(kwargs_common, **kwargs_decoder) - kwargs_encoder = dict(kwargs_common, **kwargs_encoder) + kwargs_decoder = kwargs_common.copy() + kwargs_encoder = kwargs_common.copy() + kwargs_encoder.update( + { + argument[len("encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("encoder_") + } + ) + kwargs_decoder.update( + { + argument[len("decoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("decoder_") + } + ) # Encode if needed (training, first prediction pass) encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) if encoder_hidden_states is None: encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state + encoder_hidden_states = encoder_outputs[ + 0 + ] # output the last layer hidden state else: encoder_outputs = () # Decode kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) + kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get( + "attention_mask", None + ) decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) return decoder_outputs + encoder_outputs @@ -235,6 +249,7 @@ class Model2Model(PreTrainedEncoderDecoder): decoder = BertForMaskedLM(config) model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder) """ + def __init__(self, *args, **kwargs): super(Model2Model, self).__init__(*args, **kwargs) self.tie_weights()