Raise error if stride is too high in TokenClassificationPipeline (#22942)
* Raise error if `stride` is too high
* Clarify use of `stride`
This commit is contained in:
@@ -69,7 +69,8 @@ class AggregationStrategy(ExplicitEnum):
|
|||||||
stride (`int`, *optional*):
|
stride (`int`, *optional*):
|
||||||
If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
|
If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
|
||||||
model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
|
model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
|
||||||
value of this argument defines the number of overlapping tokens between chunks.
|
value of this argument defines the number of overlapping tokens between chunks. In other words, the model
|
||||||
|
will shift forward by `tokenizer.model_max_length - stride` tokens each step.
|
||||||
aggregation_strategy (`str`, *optional*, defaults to `"none"`):
|
aggregation_strategy (`str`, *optional*, defaults to `"none"`):
|
||||||
The strategy to fuse (or not) tokens based on the model prediction.
|
The strategy to fuse (or not) tokens based on the model prediction.
|
||||||
|
|
||||||
@@ -191,6 +192,10 @@ class TokenClassificationPipeline(ChunkPipeline):
|
|||||||
if ignore_labels is not None:
|
if ignore_labels is not None:
|
||||||
postprocess_params["ignore_labels"] = ignore_labels
|
postprocess_params["ignore_labels"] = ignore_labels
|
||||||
if stride is not None:
|
if stride is not None:
|
||||||
|
if stride >= self.tokenizer.model_max_length:
|
||||||
|
raise ValueError(
|
||||||
|
"`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
|
||||||
|
)
|
||||||
if aggregation_strategy == AggregationStrategy.NONE:
|
if aggregation_strategy == AggregationStrategy.NONE:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`stride` was provided to process all the text but `aggregation_strategy="
|
"`stride` was provided to process all the text but `aggregation_strategy="
|
||||||
|
|||||||
Reference in New Issue
Block a user