Fixing GPU for token-classification in a better way. (#13856)

Co-authored-by:  Pierre Snell <pierre.snell@botpress.com>

Co-authored-by: Pierre Snell <pierre.snell@botpress.com>
This commit is contained in:
Nicolas Patry
2021-10-06 04:44:31 +02:00
committed by GitHub
parent 7d83655da9
commit e7b16f33ae
3 changed files with 26 additions and 5 deletions

View File

@@ -25,7 +25,14 @@ from transformers import (
pipeline,
)
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_torch_gpu,
slow,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -246,6 +253,19 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
],
)
@require_torch_gpu
@slow
def test_gpu(self):
sentence = "This is dummy sentence"
ner = pipeline(
"token-classification",
device=0,
aggregation_strategy=AggregationStrategy.SIMPLE,
)
output = ner(sentence)
self.assertEqual(nested_simplify(output), [])
@require_torch
@slow
def test_dbmdz_english(self):