@@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
|
||||
Handles arguments for token classification.
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, inputs: Union[str, List[str]], **kwargs):
|
||||
|
||||
if args is not None and len(args) > 0:
|
||||
inputs = list(args)
|
||||
if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
|
||||
inputs = list(inputs)
|
||||
batch_size = len(inputs)
|
||||
elif isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
batch_size = 1
|
||||
else:
|
||||
raise ValueError("At least one input is required.")
|
||||
|
||||
@@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline):
|
||||
Only exists if the offsets are available within the tokenizer
|
||||
"""
|
||||
|
||||
inputs, offset_mappings = self._args_parser(inputs, **kwargs)
|
||||
_inputs, offset_mappings = self._args_parser(inputs, **kwargs)
|
||||
|
||||
answers = []
|
||||
|
||||
for i, sentence in enumerate(inputs):
|
||||
for i, sentence in enumerate(_inputs):
|
||||
|
||||
# Manage correct placement of the tensors
|
||||
with self.device_placement():
|
||||
|
||||
Reference in New Issue
Block a user