From 367235ee52537ff7cada5e1c5c41cdd78731f092 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 31 Aug 2020 16:16:47 -0400 Subject: [PATCH] Bart can make decoder_input_ids from labels (#6758) --- src/transformers/modeling_bart.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 138b1a2f48..45b40554cd 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -58,8 +58,8 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ "facebook/bart-large-cnn", "facebook/bart-large-xsum", "facebook/mbart-large-en-ro", - # See all BART models at https://huggingface.co/models?filter=bart ] +# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart BART_START_DOCSTRING = r""" @@ -1045,6 +1045,8 @@ class BartForConditionalGeneration(PretrainedBartModel): if labels is not None: use_cache = False + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) outputs = self.model( input_ids,