From 7e7f62bfa72ca03e9f16285dad182f7c57cd8cab Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 5 Oct 2022 17:48:14 -0400 Subject: [PATCH] Fix pipeline tests for Roberta-like tokenizers (#19365) * Fix pipeline tests for Roberta-like tokenizers * Fix fix --- tests/pipelines/test_pipelines_common.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index ea32f5cac4..34684186b5 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -37,8 +37,6 @@ from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, DistilBertForSequenceClassification, - IBertConfig, - RobertaConfig, TextClassificationPipeline, TFAutoModelForSequenceClassification, pipeline, @@ -71,6 +69,16 @@ from test_module.custom_pipeline import PairClassificationPipeline # noqa E402 logger = logging.getLogger(__name__) +ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS = [ + "CamembertConfig", + "IBertConfig", + "LongformerConfig", + "MarkupLMConfig", + "RobertaConfig", + "XLMRobertaConfig", +] + + def get_checkpoint_from_architecture(architecture): try: module = importlib.import_module(architecture.__module__) @@ -194,7 +202,7 @@ class PipelineTestCaseMeta(type): try: tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint) # XLNet actually defines it as -1. - if isinstance(model.config, (RobertaConfig, IBertConfig)): + if model.config.__class__.__name__ in ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS: tokenizer.model_max_length = model.config.max_position_embeddings - 2 elif ( hasattr(model.config, "max_position_embeddings")