[BART] to each its own config + make BART compatible w/ Pipelines

cc @sshleifer
This commit is contained in:
Julien Chaumond
2020-03-02 18:56:17 -05:00
parent 6b1558bad8
commit eec5ec8071
2 changed files with 4 additions and 4 deletions

View File

@@ -22,10 +22,9 @@ from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json"
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large": _bart_large_url, "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
"bart-large-mnli": _bart_large_url, # fine as same "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json", "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
} }

View File

@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
from .configuration_bart import BartConfig
from .configuration_distilbert import DistilBertConfig from .configuration_distilbert import DistilBertConfig
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
@@ -427,7 +428,7 @@ class Pipeline(_ScikitCompat):
""" """
args = ["input_ids", "attention_mask"] args = ["input_ids", "attention_mask"]
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig)): if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)):
args += ["token_type_ids"] args += ["token_type_ids"]
# PR #1548 (CLI) There is an issue with attention_mask # PR #1548 (CLI) There is an issue with attention_mask