Fixing GPU for token-classification in a better way. (#13856)
Co-authored-by: Pierre Snell <pierre.snell@botpress.com> Co-authored-by: Pierre Snell <pierre.snell@botpress.com>
This commit is contained in:
@@ -791,7 +791,7 @@ class Pipeline(_ScikitCompat):
|
||||
elif isinstance(inputs, tuple):
|
||||
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
|
||||
elif isinstance(inputs, torch.Tensor):
|
||||
return inputs.to(self.device)
|
||||
return inputs.to(device)
|
||||
else:
|
||||
return inputs
|
||||
|
||||
|
||||
@@ -204,9 +204,10 @@ class TokenClassificationPipeline(Pipeline):
|
||||
offset_mapping = model_inputs.pop("offset_mapping", None)
|
||||
sentence = model_inputs.pop("sentence")
|
||||
if self.framework == "tf":
|
||||
outputs = self.model(model_inputs.data)[0][0].numpy()
|
||||
outputs = self.model(model_inputs.data)[0][0]
|
||||
else:
|
||||
outputs = self.model(**model_inputs)[0][0].numpy()
|
||||
outputs = self.model(**model_inputs)[0][0]
|
||||
|
||||
return {
|
||||
"outputs": outputs,
|
||||
"special_tokens_mask": special_tokens_mask,
|
||||
@@ -216,7 +217,7 @@ class TokenClassificationPipeline(Pipeline):
|
||||
}
|
||||
|
||||
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
|
||||
outputs = model_outputs["outputs"]
|
||||
outputs = model_outputs["outputs"].numpy()
|
||||
sentence = model_outputs["sentence"]
|
||||
input_ids = model_outputs["input_ids"][0]
|
||||
offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
|
||||
|
||||
@@ -25,7 +25,14 @@ from transformers import (
|
||||
pipeline,
|
||||
)
|
||||
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
@@ -246,6 +253,19 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_gpu(self):
|
||||
sentence = "This is dummy sentence"
|
||||
ner = pipeline(
|
||||
"token-classification",
|
||||
device=0,
|
||||
aggregation_strategy=AggregationStrategy.SIMPLE,
|
||||
)
|
||||
|
||||
output = ner(sentence)
|
||||
self.assertEqual(nested_simplify(output), [])
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_dbmdz_english(self):
|
||||
|
||||
Reference in New Issue
Block a user