diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index e2a4251f83..5ac16faf89 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -791,7 +791,7 @@ class Pipeline(_ScikitCompat): elif isinstance(inputs, tuple): return tuple([self._ensure_tensor_on_device(item, device) for item in inputs]) elif isinstance(inputs, torch.Tensor): - return inputs.to(self.device) + return inputs.to(device) else: return inputs diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index bd581fb7e9..fc3fce2366 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -204,9 +204,10 @@ class TokenClassificationPipeline(Pipeline): offset_mapping = model_inputs.pop("offset_mapping", None) sentence = model_inputs.pop("sentence") if self.framework == "tf": - outputs = self.model(model_inputs.data)[0][0].numpy() + outputs = self.model(model_inputs.data)[0][0] else: - outputs = self.model(**model_inputs)[0][0].numpy() + outputs = self.model(**model_inputs)[0][0] + return { "outputs": outputs, "special_tokens_mask": special_tokens_mask, @@ -216,7 +217,7 @@ class TokenClassificationPipeline(Pipeline): } def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE): - outputs = model_outputs["outputs"] + outputs = model_outputs["outputs"].numpy() sentence = model_outputs["sentence"] input_ids = model_outputs["input_ids"][0] offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None diff --git a/tests/test_pipelines_token_classification.py b/tests/test_pipelines_token_classification.py index 514dc8b94f..d94e4cc7f8 100644 --- a/tests/test_pipelines_token_classification.py +++ b/tests/test_pipelines_token_classification.py @@ -25,7 +25,14 @@ from transformers import ( pipeline, ) from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler -from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow +from transformers.testing_utils import ( + is_pipeline_test, + nested_simplify, + require_tf, + require_torch, + require_torch_gpu, + slow, +) from .test_pipelines_common import ANY, PipelineTestCaseMeta @@ -246,6 +253,19 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ], ) + @require_torch_gpu + @slow + def test_gpu(self): + sentence = "This is dummy sentence" + ner = pipeline( + "token-classification", + device=0, + aggregation_strategy=AggregationStrategy.SIMPLE, + ) + + output = ner(sentence) + self.assertEqual(nested_simplify(output), []) + @require_torch @slow def test_dbmdz_english(self):