Raise error if stride is too high in TokenClassificationPipeline (#22942)
* Raise error if `stride` is too high
* Clarify use of `stride`
@@ -69,7 +69,8 @@ class AggregationStrategy(ExplicitEnum):
         stride (`int`, *optional*):
             If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
             model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
-            value of this argument defines the number of overlapping tokens between chunks.
+            value of this argument defines the number of overlapping tokens between chunks. In other words, the model
+            will shift forward by `tokenizer.model_max_length - stride` tokens each step.
         aggregation_strategy (`str`, *optional*, defaults to `"none"`):
             The strategy to fuse (or not) tokens based on the model prediction.
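The clarified docstring can be read as simple arithmetic on chunk start offsets. The following is a minimal sketch, not the pipeline's actual chunking code: `chunk_starts` is a hypothetical helper, and it assumes the tokenizer adds no special tokens, so each chunk holds exactly `max_length` tokens.

```python
from typing import List


def chunk_starts(num_tokens: int, max_length: int, stride: int) -> List[int]:
    """Return the start offset of each chunk (hypothetical helper).

    Assumes no special tokens, so every chunk spans `max_length` tokens
    and consecutive chunks overlap by `stride` tokens.
    """
    # If stride >= max_length the window would never move forward,
    # which is why the commit rejects such values.
    if not 0 <= stride < max_length:
        raise ValueError("`stride` must satisfy 0 <= stride < max_length")

    step = max_length - stride  # tokens the window advances each step
    starts = []
    start = 0
    while True:
        starts.append(start)
        # Stop once the current window reaches the end of the text.
        if start + max_length >= num_tokens:
            break
        start += step
    return starts


print(chunk_starts(num_tokens=10, max_length=4, stride=1))
```

With 10 tokens, `max_length=4` and `stride=1`, the windows start at offsets 0, 3 and 6, so each consecutive pair of chunks shares one token.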
@@ -191,6 +192,10 @@ class TokenClassificationPipeline(ChunkPipeline):
         if ignore_labels is not None:
             postprocess_params["ignore_labels"] = ignore_labels
         if stride is not None:
+            if stride >= self.tokenizer.model_max_length:
+                raise ValueError(
+                    "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
+                )
             if aggregation_strategy == AggregationStrategy.NONE:
                 raise ValueError(
                     "`stride` was provided to process all the text but `aggregation_strategy="
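The new guard can be exercised in isolation. A minimal sketch follows, assuming a hypothetical standalone function `check_stride` rather than the pipeline method. The `num_special_tokens` parameter is an illustrative assumption that models the "or even lower" remark in the error message; it is not part of the commit, and with its default of 0 the check matches the condition added in the diff.

```python
from typing import Optional


def check_stride(
    stride: Optional[int],
    model_max_length: int,
    num_special_tokens: int = 0,  # assumption: not present in the commit
) -> None:
    """Raise ValueError if `stride` leaves no room for the window to advance."""
    if stride is None:
        return  # no chunking requested, nothing to validate

    # Tokens actually available for text once special tokens are reserved.
    usable_length = model_max_length - num_special_tokens
    if stride >= usable_length:
        raise ValueError(
            "`stride` must be less than `tokenizer.model_max_length` "
            "(or even lower if the tokenizer adds special tokens)"
        )


check_stride(stride=128, model_max_length=512)  # passes silently
try:
    check_stride(stride=512, model_max_length=512)
except ValueError as err:
    print(f"rejected: {err}")
```

Failing early here gives a clear message at pipeline construction time instead of an error raised later, once chunking has already begun.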