Fixing NER pipeline for list inputs. (#10184)

Fixes #10168
This commit is contained in:
Nicolas Patry
2021-02-15 12:22:45 +01:00
committed by GitHub
parent 587197dcd2
commit 900daec24e
2 changed files with 60 additions and 20 deletions

View File

@@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
Handles arguments for token classification.
"""
def __call__(self, *args, **kwargs):
def __call__(self, inputs: Union[str, List[str]], **kwargs):
if args is not None and len(args) > 0:
inputs = list(args)
if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
inputs = list(inputs)
batch_size = len(inputs)
elif isinstance(inputs, str):
inputs = [inputs]
batch_size = 1
else:
raise ValueError("At least one input is required.")
@@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline):
Only exists if the offsets are available within the tokenizer
"""
inputs, offset_mappings = self._args_parser(inputs, **kwargs)
_inputs, offset_mappings = self._args_parser(inputs, **kwargs)
answers = []
for i, sentence in enumerate(inputs):
for i, sentence in enumerate(_inputs):
# Manage correct placement of the tensors
with self.device_placement():