From b5f06d6c59a183442663c67a33545a9c5c778ef2 Mon Sep 17 00:00:00 2001 From: Connor Boyle Date: Mon, 24 Apr 2023 06:27:49 -0700 Subject: [PATCH] Raise error if `stride` is too high in `TokenClassificationPipeline` (#22942) * Raise error if `stride` is too high * Clarify use of `stride` --- src/transformers/pipelines/token_classification.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index 1a2a96b398..7698b37f31 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -69,7 +69,8 @@ class AggregationStrategy(ExplicitEnum): stride (`int`, *optional*): If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The - value of this argument defines the number of overlapping tokens between chunks. + value of this argument defines the number of overlapping tokens between chunks. In other words, the model + will shift forward by `tokenizer.model_max_length - stride` tokens each step. aggregation_strategy (`str`, *optional*, defaults to `"none"`): The strategy to fuse (or not) tokens based on the model prediction. @@ -191,6 +192,10 @@ class TokenClassificationPipeline(ChunkPipeline): if ignore_labels is not None: postprocess_params["ignore_labels"] = ignore_labels if stride is not None: + if stride >= self.tokenizer.model_max_length: + raise ValueError( + "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)" + ) if aggregation_strategy == AggregationStrategy.NONE: raise ValueError( "`stride` was provided to process all the text but `aggregation_strategy="