enable low-precision pipeline (#31625)
* enable low-precision pipeline * fix parameter for ASR * reformat * fix asr bug * fix bug for zero-shot * add dtype check * rm useless comments * add np.float16 check * Update src/transformers/pipelines/image_classification.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/pipelines/token_classification.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fix comments * fix asr check * make fixup * No more need for is_torch_available() --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -21,11 +21,22 @@ from transformers import (
|
||||
ZeroShotClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
# These 2 model types require different inputs than those of the usual text models.
|
||||
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
|
||||
|
||||
@@ -176,6 +187,48 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_fp16(self):
|
||||
zero_shot_classifier = pipeline(
|
||||
"zero-shot-classification",
|
||||
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
|
||||
framework="pt",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
outputs = zero_shot_classifier(
|
||||
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
{
|
||||
"sequence": "Who are you voting for in 2020?",
|
||||
"labels": ["science", "public health", "politics"],
|
||||
"scores": [0.333, 0.333, 0.333],
|
||||
},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_bf16(self):
|
||||
zero_shot_classifier = pipeline(
|
||||
"zero-shot-classification",
|
||||
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
|
||||
framework="pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
outputs = zero_shot_classifier(
|
||||
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
{
|
||||
"sequence": "Who are you voting for in 2020?",
|
||||
"labels": ["science", "public health", "politics"],
|
||||
"scores": [0.333, 0.333, 0.333],
|
||||
},
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
zero_shot_classifier = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user