Bart can make decoder_input_ids from labels (#6758)
This commit is contained in:
@@ -58,8 +58,8 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/bart-large-cnn",
|
||||
"facebook/bart-large-xsum",
|
||||
"facebook/mbart-large-en-ro",
|
||||
# See all BART models at https://huggingface.co/models?filter=bart
|
||||
]
|
||||
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
|
||||
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
@@ -1045,6 +1045,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
|
||||
Reference in New Issue
Block a user