From eec5ec807135ae61fa2266f3c7ad947cc207abda Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 2 Mar 2020 18:56:17 -0500 Subject: [PATCH] [BART] to each its own config + make BART compatible w/ Pipelines cc @sshleifer --- src/transformers/configuration_bart.py | 5 ++--- src/transformers/pipelines.py | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 7eb3bd7fe8..a80e743c42 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -22,10 +22,9 @@ from .configuration_utils import PretrainedConfig logger = logging.getLogger(__name__) -_bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json" BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "bart-large": _bart_large_url, - "bart-large-mnli": _bart_large_url, # fine as same + "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json", + "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json", "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json", } diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index cd7b9ca55d..7b73a70d4d 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig +from .configuration_bart import BartConfig from .configuration_distilbert import DistilBertConfig from .configuration_roberta import RobertaConfig from .configuration_utils import PretrainedConfig @@ -427,7 +428,7 @@ class Pipeline(_ScikitCompat): """ args = ["input_ids", "attention_mask"] - if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig)): + if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)): args += ["token_type_ids"] # PR #1548 (CLI) There is an issue with attention_mask