change kwargs processing
This commit is contained in:
@@ -114,23 +114,28 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||||
# that apply to the model as a whole.
|
# that apply to the model as a whole.
|
||||||
# We let the specific kwargs override the common ones in case of conflict.
|
# We let the specific kwargs override the common ones in case of conflict.
|
||||||
kwargs_encoder = {
|
|
||||||
argument[len("encoder_"):]: value
|
|
||||||
for argument, value in kwargs.items()
|
|
||||||
if argument.startswith("encoder_")
|
|
||||||
}
|
|
||||||
kwargs_decoder = {
|
|
||||||
argument[len("decoder_"):]: value
|
|
||||||
for argument, value in kwargs.items()
|
|
||||||
if argument.startswith("decoder_")
|
|
||||||
}
|
|
||||||
kwargs_common = {
|
kwargs_common = {
|
||||||
argument: value
|
argument: value
|
||||||
for argument, value in kwargs.items()
|
for argument, value in kwargs.items()
|
||||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
if not argument.startswith("encoder_")
|
||||||
|
and not argument.startswith("decoder_")
|
||||||
}
|
}
|
||||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
kwargs_decoder = kwargs_common.copy()
|
||||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
kwargs_encoder = kwargs_common.copy()
|
||||||
|
kwargs_encoder.update(
|
||||||
|
{
|
||||||
|
argument[len("encoder_") :]: value
|
||||||
|
for argument, value in kwargs.items()
|
||||||
|
if argument.startswith("encoder_")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
kwargs_decoder.update(
|
||||||
|
{
|
||||||
|
argument[len("decoder_") :]: value
|
||||||
|
for argument, value in kwargs.items()
|
||||||
|
if argument.startswith("decoder_")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Load and initialize the encoder and decoder
|
# Load and initialize the encoder and decoder
|
||||||
# The distinction between encoder and decoder at the model level is made
|
# The distinction between encoder and decoder at the model level is made
|
||||||
@@ -185,35 +190,44 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||||
# that apply to the model as a whole.
|
# that apply to the model as a whole.
|
||||||
# We let the specific kwargs override the common ones in case of conflict.
|
# We let the specific kwargs override the common ones in case of conflict.
|
||||||
kwargs_encoder = {
|
|
||||||
argument[len("encoder_"):]: value
|
|
||||||
for argument, value in kwargs.items()
|
|
||||||
if argument.startswith("encoder_")
|
|
||||||
}
|
|
||||||
kwargs_decoder = {
|
|
||||||
argument[len("decoder_"):]: value
|
|
||||||
for argument, value in kwargs.items()
|
|
||||||
if argument.startswith("decoder_")
|
|
||||||
}
|
|
||||||
kwargs_common = {
|
kwargs_common = {
|
||||||
argument: value
|
argument: value
|
||||||
for argument, value in kwargs.items()
|
for argument, value in kwargs.items()
|
||||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
if not argument.startswith("encoder_")
|
||||||
|
and not argument.startswith("decoder_")
|
||||||
}
|
}
|
||||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
kwargs_decoder = kwargs_common.copy()
|
||||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
kwargs_encoder = kwargs_common.copy()
|
||||||
|
kwargs_encoder.update(
|
||||||
|
{
|
||||||
|
argument[len("encoder_") :]: value
|
||||||
|
for argument, value in kwargs.items()
|
||||||
|
if argument.startswith("encoder_")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
kwargs_decoder.update(
|
||||||
|
{
|
||||||
|
argument[len("decoder_") :]: value
|
||||||
|
for argument, value in kwargs.items()
|
||||||
|
if argument.startswith("decoder_")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Encode if needed (training, first prediction pass)
|
# Encode if needed (training, first prediction pass)
|
||||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||||
if encoder_hidden_states is None:
|
if encoder_hidden_states is None:
|
||||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||||
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
|
encoder_hidden_states = encoder_outputs[
|
||||||
|
0
|
||||||
|
] # output the last layer hidden state
|
||||||
else:
|
else:
|
||||||
encoder_outputs = ()
|
encoder_outputs = ()
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||||
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
|
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get(
|
||||||
|
"attention_mask", None
|
||||||
|
)
|
||||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||||
|
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
@@ -235,6 +249,7 @@ class Model2Model(PreTrainedEncoderDecoder):
|
|||||||
decoder = BertForMaskedLM(config)
|
decoder = BertForMaskedLM(config)
|
||||||
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
|
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Model2Model, self).__init__(*args, **kwargs)
|
super(Model2Model, self).__init__(*args, **kwargs)
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|||||||
Reference in New Issue
Block a user