Fix pipeline tests for Roberta-like tokenizers (#19365)
* Fix pipeline tests for Roberta-like tokenizers
* Fix fix
This commit is contained in:
@@ -37,8 +37,6 @@ from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
DistilBertForSequenceClassification,
|
||||
IBertConfig,
|
||||
RobertaConfig,
|
||||
TextClassificationPipeline,
|
||||
TFAutoModelForSequenceClassification,
|
||||
pipeline,
|
||||
@@ -71,6 +69,16 @@ from test_module.custom_pipeline import PairClassificationPipeline # noqa E402
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS = [
|
||||
"CamembertConfig",
|
||||
"IBertConfig",
|
||||
"LongformerConfig",
|
||||
"MarkupLMConfig",
|
||||
"RobertaConfig",
|
||||
"XLMRobertaConfig",
|
||||
]
|
||||
|
||||
|
||||
def get_checkpoint_from_architecture(architecture):
|
||||
try:
|
||||
module = importlib.import_module(architecture.__module__)
|
||||
@@ -194,7 +202,7 @@ class PipelineTestCaseMeta(type):
|
||||
try:
|
||||
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
|
||||
# XLNet actually defines it as -1.
|
||||
- if isinstance(model.config, (RobertaConfig, IBertConfig)):
|
||||
+ if model.config.__class__.__name__ in ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS:
|
||||
tokenizer.model_max_length = model.config.max_position_embeddings - 2
|
||||
elif (
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
|
||||
Reference in New Issue
Block a user