[BART] to each its own config + make BART compatible w/ Pipelines
cc @sshleifer
This commit is contained in:
@@ -22,10 +22,9 @@ from .configuration_utils import PretrainedConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json"
|
||||
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"bart-large": _bart_large_url,
|
||||
"bart-large-mnli": _bart_large_url, # fine as same
|
||||
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
|
||||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
||||
from .configuration_bart import BartConfig
|
||||
from .configuration_distilbert import DistilBertConfig
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .configuration_utils import PretrainedConfig
|
||||
@@ -427,7 +428,7 @@ class Pipeline(_ScikitCompat):
|
||||
"""
|
||||
args = ["input_ids", "attention_mask"]
|
||||
|
||||
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig)):
|
||||
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)):
|
||||
args += ["token_type_ids"]
|
||||
|
||||
# PR #1548 (CLI) There is an issue with attention_mask
|
||||
|
||||
Reference in New Issue
Block a user