New BartModel (#2745)
* Results same as fairseq * Wrote a ton of tests * Struggled with api signatures * added some docs
This commit is contained in:
@@ -236,42 +236,6 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
@staticmethod
|
||||
def prepare_model_kwargs(**kwargs):
|
||||
""" Prepare the encoder and decoder's keyword arguments.
|
||||
|
||||
Keyword arguments come in 3 flavors:
|
||||
- encoder-specific (prefixed by `encoder_`)
|
||||
- decoder-specific (prefixed by `decoder_`)
|
||||
- those that apply to the model as whole.
|
||||
|
||||
We let the specific kwargs override the common ones in case of
|
||||
conflict.
|
||||
"""
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
||||
}
|
||||
decoder_kwargs = kwargs_common.copy()
|
||||
encoder_kwargs = kwargs_common.copy()
|
||||
encoder_kwargs.update(
|
||||
{
|
||||
argument[len("encoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
)
|
||||
decoder_kwargs.update(
|
||||
{
|
||||
argument[len("decoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
)
|
||||
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
|
||||
return encoder_kwargs, decoder_kwargs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedEncoderDecoder):
|
||||
r"""
|
||||
|
||||
Reference in New Issue
Block a user