Bart can make decoder_input_ids from labels (#6758)
This commit is contained in:
@@ -58,8 +58,8 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
"facebook/bart-large-cnn",
|
"facebook/bart-large-cnn",
|
||||||
"facebook/bart-large-xsum",
|
"facebook/bart-large-xsum",
|
||||||
"facebook/mbart-large-en-ro",
|
"facebook/mbart-large-en-ro",
|
||||||
# See all BART models at https://huggingface.co/models?filter=bart
|
|
||||||
]
|
]
|
||||||
|
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
|
||||||
|
|
||||||
|
|
||||||
BART_START_DOCSTRING = r"""
|
BART_START_DOCSTRING = r"""
|
||||||
@@ -1045,6 +1045,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
if decoder_input_ids is None:
|
||||||
|
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user