enable low-precision pipeline (#31625)
* enable low-precision pipeline
* fix parameter for ASR
* reformat
* fix asr bug
* fix bug for zero-shot
* add dtype check
* rm useless comments
* add np.float16 check
* Update src/transformers/pipelines/image_classification.py
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Update src/transformers/pipelines/token_classification.py
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* fix comments
* fix asr check
* make fixup
* No more need for is_torch_available()

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -565,7 +565,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
||||||
stride = None
|
stride = None
|
||||||
for outputs in model_outputs:
|
for outputs in model_outputs:
|
||||||
items = outputs[key].numpy()
|
if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16):
|
||||||
|
items = outputs[key].to(torch.float32).numpy()
|
||||||
|
else:
|
||||||
|
items = outputs[key].numpy()
|
||||||
stride = outputs.get("stride", None)
|
stride = outputs.get("stride", None)
|
||||||
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
|
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
|
||||||
total_n, left, right = stride
|
total_n, left, right = stride
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ if is_tf_available():
|
|||||||
|
|
||||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
|
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
|
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
|
||||||
|
|
||||||
|
|
||||||
@@ -299,7 +301,11 @@ class TokenClassificationPipeline(ChunkPipeline):
|
|||||||
ignore_labels = ["O"]
|
ignore_labels = ["O"]
|
||||||
all_entities = []
|
all_entities = []
|
||||||
for model_outputs in all_outputs:
|
for model_outputs in all_outputs:
|
||||||
logits = model_outputs["logits"][0].numpy()
|
if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
|
||||||
|
logits = model_outputs["logits"][0].to(torch.float32).numpy()
|
||||||
|
else:
|
||||||
|
logits = model_outputs["logits"][0].numpy()
|
||||||
|
|
||||||
sentence = all_outputs[0]["sentence"]
|
sentence = all_outputs[0]["sentence"]
|
||||||
input_ids = model_outputs["input_ids"][0]
|
input_ids = model_outputs["input_ids"][0]
|
||||||
offset_mapping = (
|
offset_mapping = (
|
||||||
|
|||||||
@@ -2143,7 +2143,7 @@ def nested_simplify(obj, decimals=3):
|
|||||||
return nested_simplify(obj.numpy().tolist())
|
return nested_simplify(obj.numpy().tolist())
|
||||||
elif isinstance(obj, float):
|
elif isinstance(obj, float):
|
||||||
return round(obj, decimals)
|
return round(obj, decimals)
|
||||||
elif isinstance(obj, (np.int32, np.float32)):
|
elif isinstance(obj, (np.int32, np.float32, np.float16)):
|
||||||
return nested_simplify(obj.item(), decimals)
|
return nested_simplify(obj.item(), decimals)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Not supported: {type(obj)}")
|
raise Exception(f"Not supported: {type(obj)}")
|
||||||
|
|||||||
@@ -167,6 +167,48 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
):
|
):
|
||||||
_ = speech_recognizer(waveform, return_timestamps="char")
|
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_fp16(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="facebook/s2t-small-mustc-en-fr-st",
|
||||||
|
tokenizer="facebook/s2t-small-mustc-en-fr-st",
|
||||||
|
framework="pt",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||||
|
output = speech_recognizer(waveform)
|
||||||
|
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||||
|
output = speech_recognizer(waveform, chunk_length_s=10)
|
||||||
|
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||||
|
|
||||||
|
# Non CTC models cannot use return_timestamps
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
|
||||||
|
):
|
||||||
|
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_bf16(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="facebook/s2t-small-mustc-en-fr-st",
|
||||||
|
tokenizer="facebook/s2t-small-mustc-en-fr-st",
|
||||||
|
framework="pt",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||||
|
output = speech_recognizer(waveform)
|
||||||
|
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||||
|
output = speech_recognizer(waveform, chunk_length_s=10)
|
||||||
|
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||||
|
|
||||||
|
# Non CTC models cannot use return_timestamps
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
|
||||||
|
):
|
||||||
|
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
def test_whisper_fp16(self):
|
def test_whisper_fp16(self):
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from transformers import (
|
|||||||
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
|
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
|
is_torch_available,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -38,6 +39,10 @@ from transformers.testing_utils import (
|
|||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
|
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
|
||||||
|
|
||||||
# These 2 model types require different inputs than those of the usual text models.
|
# These 2 model types require different inputs than those of the usual text models.
|
||||||
@@ -841,6 +846,36 @@ class TokenClassificationPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_fp16(self):
|
||||||
|
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||||
|
token_classifier = pipeline(
|
||||||
|
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
outputs = token_classifier("This is a test !")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[
|
||||||
|
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
|
||||||
|
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_bf16(self):
|
||||||
|
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||||
|
token_classifier = pipeline(
|
||||||
|
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
outputs = token_classifier("This is a test !")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[
|
||||||
|
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
|
||||||
|
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
||||||
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
||||||
|
|||||||
@@ -21,11 +21,22 @@ from transformers import (
|
|||||||
ZeroShotClassificationPipeline,
|
ZeroShotClassificationPipeline,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
from transformers.testing_utils import (
|
||||||
|
is_pipeline_test,
|
||||||
|
is_torch_available,
|
||||||
|
nested_simplify,
|
||||||
|
require_tf,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
# These 2 model types require different inputs than those of the usual text models.
|
# These 2 model types require different inputs than those of the usual text models.
|
||||||
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
|
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
|
||||||
|
|
||||||
@@ -176,6 +187,48 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_fp16(self):
|
||||||
|
zero_shot_classifier = pipeline(
|
||||||
|
"zero-shot-classification",
|
||||||
|
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
|
||||||
|
framework="pt",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
outputs = zero_shot_classifier(
|
||||||
|
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
{
|
||||||
|
"sequence": "Who are you voting for in 2020?",
|
||||||
|
"labels": ["science", "public health", "politics"],
|
||||||
|
"scores": [0.333, 0.333, 0.333],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_bf16(self):
|
||||||
|
zero_shot_classifier = pipeline(
|
||||||
|
"zero-shot-classification",
|
||||||
|
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
|
||||||
|
framework="pt",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
outputs = zero_shot_classifier(
|
||||||
|
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
{
|
||||||
|
"sequence": "Who are you voting for in 2020?",
|
||||||
|
"labels": ["science", "public health", "politics"],
|
||||||
|
"scores": [0.333, 0.333, 0.333],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
zero_shot_classifier = pipeline(
|
zero_shot_classifier = pipeline(
|
||||||
|
|||||||
Reference in New Issue
Block a user