Fix FP16 inference in TextGenerationPipeline (#20913)
* add torch_dtype attribute to Pipeline * Use torch_dtype to cast input tensor type in AutomaticSpeechRecognitionPipeline * Fix code quality * Add TextGenerationPipeline fp16 test * Fix code quality * Remove useless require in tests Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -242,8 +242,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
||||
if "ignore_warning" in kwargs:
|
||||
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
|
||||
if "torch_dtype" in kwargs:
|
||||
preprocess_params["dtype"] = kwargs["torch_dtype"]
|
||||
|
||||
postprocess_params = {}
|
||||
if "decoder_kwargs" in kwargs:
|
||||
@@ -253,7 +251,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
|
||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
|
||||
if isinstance(inputs, str):
|
||||
if inputs.startswith("http://") or inputs.startswith("https://"):
|
||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||
@@ -336,14 +334,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
raise ValueError("Chunk length must be superior to stride length")
|
||||
|
||||
# make sure that
|
||||
for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, dtype):
|
||||
for item in chunk_iter(
|
||||
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
|
||||
):
|
||||
yield item
|
||||
else:
|
||||
processed = self.feature_extractor(
|
||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
if dtype is not None:
|
||||
processed = processed.to(dtype=dtype)
|
||||
if self.torch_dtype is not None:
|
||||
processed = processed.to(dtype=self.torch_dtype)
|
||||
if stride is not None:
|
||||
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||
raise ValueError("Stride is only usable with CTC models, try removing it")
|
||||
|
||||
@@ -748,6 +748,7 @@ class Pipeline(_ScikitCompat):
|
||||
task: str = "",
|
||||
args_parser: ArgumentHandler = None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
|
||||
binary_output: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -771,6 +772,7 @@ class Pipeline(_ScikitCompat):
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.binary_output = binary_output
|
||||
|
||||
# Special handling
|
||||
|
||||
@@ -300,3 +300,11 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
def test_small_model_fp16(self):
|
||||
import torch
|
||||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
|
||||
pipe("This is a test")
|
||||
|
||||
Reference in New Issue
Block a user