Fix pipeline tests for Roberta-like tokenizers (#19365)
* Fix pipeline tests for Roberta-like tokenizers
* Fix fix
This commit is contained in:
@@ -37,8 +37,6 @@ from transformers import (
|
|||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DistilBertForSequenceClassification,
|
DistilBertForSequenceClassification,
|
||||||
IBertConfig,
|
|
||||||
RobertaConfig,
|
|
||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
pipeline,
|
pipeline,
|
||||||
@@ -71,6 +69,16 @@ from test_module.custom_pipeline import PairClassificationPipeline # noqa E402
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS = [
|
||||||
|
"CamembertConfig",
|
||||||
|
"IBertConfig",
|
||||||
|
"LongformerConfig",
|
||||||
|
"MarkupLMConfig",
|
||||||
|
"RobertaConfig",
|
||||||
|
"XLMRobertaConfig",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_checkpoint_from_architecture(architecture):
|
def get_checkpoint_from_architecture(architecture):
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(architecture.__module__)
|
module = importlib.import_module(architecture.__module__)
|
||||||
@@ -194,7 +202,7 @@ class PipelineTestCaseMeta(type):
|
|||||||
try:
|
try:
|
||||||
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
|
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
|
||||||
# XLNet actually defines it as -1.
|
# XLNet actually defines it as -1.
|
||||||
if isinstance(model.config, (RobertaConfig, IBertConfig)):
|
if model.config.__class__.__name__ in ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS:
|
||||||
tokenizer.model_max_length = model.config.max_position_embeddings - 2
|
tokenizer.model_max_length = model.config.max_position_embeddings - 2
|
||||||
elif (
|
elif (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
|
|||||||
Reference in New Issue
Block a user