Fixing GPU for token-classification in a better way. (#13856)
Co-authored-by: Pierre Snell <pierre.snell@botpress.com> Co-authored-by: Pierre Snell <pierre.snell@botpress.com>
This commit is contained in:
@@ -25,7 +25,14 @@ from transformers import (
|
||||
pipeline,
|
||||
)
|
||||
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
@@ -246,6 +253,19 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_gpu(self):
|
||||
sentence = "This is dummy sentence"
|
||||
ner = pipeline(
|
||||
"token-classification",
|
||||
device=0,
|
||||
aggregation_strategy=AggregationStrategy.SIMPLE,
|
||||
)
|
||||
|
||||
output = ner(sentence)
|
||||
self.assertEqual(nested_simplify(output), [])
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_dbmdz_english(self):
|
||||
|
||||
Reference in New Issue
Block a user