From fe65657de112531a2d5303491f245f9e7534ae8d Mon Sep 17 00:00:00 2001 From: bofeng huang Date: Thu, 29 Dec 2022 08:19:25 +0100 Subject: [PATCH] Fix FP16 inference in TextGenerationPipeline (#20913) * add torch_dtype attribute to Pipeline * Use torch_dtype to cast input tensor type in AutomaticSpeechRecognitionPipeline * Fix code quality * Add TextGenerationPipeline fp16 test * Fix code quality * Remove useless require in tests Co-authored-by: Nicolas Patry Co-authored-by: Nicolas Patry --- .../pipelines/automatic_speech_recognition.py | 12 ++++++------ src/transformers/pipelines/base.py | 2 ++ tests/pipelines/test_pipelines_text_generation.py | 8 ++++++++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 9e4f45fd79..ebbfcfa4c5 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -242,8 +242,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): preprocess_params["stride_length_s"] = kwargs["stride_length_s"] if "ignore_warning" in kwargs: preprocess_params["ignore_warning"] = kwargs["ignore_warning"] - if "torch_dtype" in kwargs: - preprocess_params["dtype"] = kwargs["torch_dtype"] postprocess_params = {} if "decoder_kwargs" in kwargs: @@ -253,7 +251,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): return preprocess_params, {}, postprocess_params - def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None): + def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False): if isinstance(inputs, str): if inputs.startswith("http://") or inputs.startswith("https://"): # We need to actually check for a real protocol, otherwise it's impossible to use a local file @@ -336,14 +334,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): raise ValueError("Chunk length must be superior to stride length") # make sure that - for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, dtype): + for item in chunk_iter( + inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype + ): yield item else: processed = self.feature_extractor( inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" ) - if dtype is not None: - processed = processed.to(dtype=dtype) + if self.torch_dtype is not None: + processed = processed.to(dtype=self.torch_dtype) if stride is not None: if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values(): raise ValueError("Stride is only usable with CTC models, try removing it") diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 4205ef2eb2..a0f9b3bd02 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -748,6 +748,7 @@ class Pipeline(_ScikitCompat): task: str = "", args_parser: ArgumentHandler = None, device: Union[int, str, "torch.device"] = -1, + torch_dtype: Optional[Union[str, "torch.dtype"]] = None, binary_output: bool = False, **kwargs, ): @@ -771,6 +772,7 @@ class Pipeline(_ScikitCompat): self.device = torch.device(f"cuda:{device}") else: self.device = device + self.torch_dtype = torch_dtype self.binary_output = binary_output # Special handling diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index c0aee8b2db..4de6f878dd 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -300,3 +300,11 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM } ], ) + + @require_torch + @require_torch_gpu + def test_small_model_fp16(self): + import torch + + pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16) + pipe("This is a test")