[BART] to each its own config + make BART compatible w/ Pipelines
cc @sshleifer
This commit is contained in:
@@ -22,10 +22,9 @@ from .configuration_utils import PretrainedConfig
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json"
|
|
||||||
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"bart-large": _bart_large_url,
|
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
|
||||||
"bart-large-mnli": _bart_large_url, # fine as same
|
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
|
||||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
|
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
||||||
|
from .configuration_bart import BartConfig
|
||||||
from .configuration_distilbert import DistilBertConfig
|
from .configuration_distilbert import DistilBertConfig
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
@@ -427,7 +428,7 @@ class Pipeline(_ScikitCompat):
|
|||||||
"""
|
"""
|
||||||
args = ["input_ids", "attention_mask"]
|
args = ["input_ids", "attention_mask"]
|
||||||
|
|
||||||
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig)):
|
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)):
|
||||||
args += ["token_type_ids"]
|
args += ["token_type_ids"]
|
||||||
|
|
||||||
# PR #1548 (CLI) There is an issue with attention_mask
|
# PR #1548 (CLI) There is an issue with attention_mask
|
||||||
|
|||||||
Reference in New Issue
Block a user