enable low-precision pipeline (#31625)
* enable low-precision pipeline * fix parameter for ASR * reformat * fix asr bug * fix bug for zero-shot * add dtype check * rm useless comments * add np.float16 check * Update src/transformers/pipelines/image_classification.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/pipelines/token_classification.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fix comments * fix asr check * make fixup * No more need for is_torch_available() --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers import (
|
||||
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
@@ -38,6 +39,10 @@ from transformers.testing_utils import (
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
|
||||
|
||||
# These 2 model types require different inputs than those of the usual text models.
|
||||
@@ -841,6 +846,36 @@ class TokenClassificationPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_fp16(self):
|
||||
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||
token_classifier = pipeline(
|
||||
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16
|
||||
)
|
||||
outputs = token_classifier("This is a test !")
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[
|
||||
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
|
||||
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_bf16(self):
|
||||
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||
token_classifier = pipeline(
|
||||
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16
|
||||
)
|
||||
outputs = token_classifier("This is a test !")
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[
|
||||
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
|
||||
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
||||
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
||||
|
||||
Reference in New Issue
Block a user