enable low-precision pipeline (#31625)

* enable low-precision pipeline

* fix parameter for ASR

* reformat

* fix asr bug

* fix bug for zero-shot

* add dtype check

* rm useless comments

* add np.float16 check

* Update src/transformers/pipelines/image_classification.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* fix comments

* fix asr check

* make fixup

* No more need for is_torch_available()

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
jiqing-feng
2024-09-21 07:43:30 +08:00
committed by GitHub
parent 7b2b536a81
commit 49a0bef4c1
6 changed files with 143 additions and 4 deletions

View File

@@ -27,6 +27,7 @@ from transformers import (
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
@@ -38,6 +39,10 @@ from transformers.testing_utils import (
from .test_pipelines_common import ANY
if is_torch_available():
import torch
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
# These 2 model types require different inputs than those of the usual text models.
@@ -841,6 +846,36 @@ class TokenClassificationPipelineTests(unittest.TestCase):
],
)
@require_torch
def test_small_model_pt_fp16(self):
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
token_classifier = pipeline(
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16
)
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
],
)
@require_torch
def test_small_model_pt_bf16(self):
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
token_classifier = pipeline(
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16
)
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
],
)
@require_torch
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"