Add Flaubert OnnxConfig to Transformers (#16279)
* Add Flaubert to ONNX to make it available for conversion. * Fixed features for FlauBERT. fixup command remove flaubert to docs list. Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
This commit is contained in:
@@ -52,6 +52,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- Data2VecText
|
- Data2VecText
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
- ELECTRA
|
- ELECTRA
|
||||||
|
- FlauBERT
|
||||||
- GPT Neo
|
- GPT Neo
|
||||||
- I-BERT
|
- I-BERT
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_tf_available, is_torch_available
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig"],
|
"configuration_flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertOnnxConfig"],
|
||||||
"tokenization_flaubert": ["FlaubertTokenizer"],
|
"tokenization_flaubert": ["FlaubertTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertOnnxConfig
|
||||||
from .tokenization_flaubert import FlaubertTokenizer
|
from .tokenization_flaubert import FlaubertTokenizer
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|||||||
@@ -14,6 +14,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Flaubert configuration, based on XLM."""
|
""" Flaubert configuration, based on XLM."""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..xlm.configuration_xlm import XLMConfig
|
from ..xlm.configuration_xlm import XLMConfig
|
||||||
|
|
||||||
@@ -137,3 +141,14 @@ class FlaubertConfig(XLMConfig):
|
|||||||
self.layerdrop = layerdrop
|
self.layerdrop = layerdrop
|
||||||
self.pre_norm = pre_norm
|
self.pre_norm = pre_norm
|
||||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaubertOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from ..models.bert import BertOnnxConfig
|
|||||||
from ..models.camembert import CamembertOnnxConfig
|
from ..models.camembert import CamembertOnnxConfig
|
||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
from ..models.electra import ElectraOnnxConfig
|
from ..models.electra import ElectraOnnxConfig
|
||||||
|
from ..models.flaubert import FlaubertOnnxConfig
|
||||||
from ..models.gpt2 import GPT2OnnxConfig
|
from ..models.gpt2 import GPT2OnnxConfig
|
||||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
from ..models.ibert import IBertOnnxConfig
|
from ..models.ibert import IBertOnnxConfig
|
||||||
@@ -179,6 +180,15 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=DistilBertOnnxConfig,
|
onnx_config_cls=DistilBertOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"flaubert": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"causal-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=FlaubertOnnxConfig,
|
||||||
|
),
|
||||||
"marian": supported_features_mapping(
|
"marian": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"default-with-past",
|
"default-with-past",
|
||||||
|
|||||||
Reference in New Issue
Block a user