enable low-precision pipeline (#31625)

* enable low-precision pipeline

* fix parameter for ASR

* reformat

* fix asr bug

* fix bug for zero-shot

* add dtype check

* rm useless comments

* add np.float16 check

* Update src/transformers/pipelines/image_classification.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* fix comments

* fix asr check

* make fixup

* No more need for is_torch_available()

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
jiqing-feng
2024-09-21 07:43:30 +08:00
committed by GitHub
parent 7b2b536a81
commit 49a0bef4c1
6 changed files with 143 additions and 4 deletions

View File

@@ -21,11 +21,22 @@ from transformers import (
ZeroShotClassificationPipeline,
pipeline,
)
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
slow,
)
from .test_pipelines_common import ANY
if is_torch_available():
import torch
# These 2 model types require different inputs than those of the usual text models.
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
@@ -176,6 +187,48 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
},
)
@require_torch
def test_small_model_pt_fp16(self):
zero_shot_classifier = pipeline(
"zero-shot-classification",
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
framework="pt",
torch_dtype=torch.float16,
)
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)
self.assertEqual(
nested_simplify(outputs),
{
"sequence": "Who are you voting for in 2020?",
"labels": ["science", "public health", "politics"],
"scores": [0.333, 0.333, 0.333],
},
)
@require_torch
def test_small_model_pt_bf16(self):
zero_shot_classifier = pipeline(
"zero-shot-classification",
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
framework="pt",
torch_dtype=torch.bfloat16,
)
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)
self.assertEqual(
nested_simplify(outputs),
{
"sequence": "Who are you voting for in 2020?",
"labels": ["science", "public health", "politics"],
"scores": [0.333, 0.333, 0.333],
},
)
@require_tf
def test_small_model_tf(self):
zero_shot_classifier = pipeline(