Fix FP16 inference in TextGenerationPipeline (#20913)
* Add torch_dtype attribute to Pipeline
* Use torch_dtype to cast input tensor type in AutomaticSpeechRecognitionPipeline
* Fix code quality
* Add TextGenerationPipeline fp16 test
* Fix code quality
* Remove useless require in tests

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -242,8 +242,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
||||||
if "ignore_warning" in kwargs:
|
if "ignore_warning" in kwargs:
|
||||||
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
|
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
|
||||||
if "torch_dtype" in kwargs:
|
|
||||||
preprocess_params["dtype"] = kwargs["torch_dtype"]
|
|
||||||
|
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
if "decoder_kwargs" in kwargs:
|
if "decoder_kwargs" in kwargs:
|
||||||
@@ -253,7 +251,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
return preprocess_params, {}, postprocess_params
|
return preprocess_params, {}, postprocess_params
|
||||||
|
|
||||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
|
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
|
||||||
if isinstance(inputs, str):
|
if isinstance(inputs, str):
|
||||||
if inputs.startswith("http://") or inputs.startswith("https://"):
|
if inputs.startswith("http://") or inputs.startswith("https://"):
|
||||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||||
@@ -336,14 +334,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
raise ValueError("Chunk length must be superior to stride length")
|
raise ValueError("Chunk length must be superior to stride length")
|
||||||
|
|
||||||
# make sure that
|
# make sure that
|
||||||
for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, dtype):
|
for item in chunk_iter(
|
||||||
|
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
|
||||||
|
):
|
||||||
yield item
|
yield item
|
||||||
else:
|
else:
|
||||||
processed = self.feature_extractor(
|
processed = self.feature_extractor(
|
||||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||||
)
|
)
|
||||||
if dtype is not None:
|
if self.torch_dtype is not None:
|
||||||
processed = processed.to(dtype=dtype)
|
processed = processed.to(dtype=self.torch_dtype)
|
||||||
if stride is not None:
|
if stride is not None:
|
||||||
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||||
raise ValueError("Stride is only usable with CTC models, try removing it")
|
raise ValueError("Stride is only usable with CTC models, try removing it")
|
||||||
|
|||||||
@@ -748,6 +748,7 @@ class Pipeline(_ScikitCompat):
|
|||||||
task: str = "",
|
task: str = "",
|
||||||
args_parser: ArgumentHandler = None,
|
args_parser: ArgumentHandler = None,
|
||||||
device: Union[int, str, "torch.device"] = -1,
|
device: Union[int, str, "torch.device"] = -1,
|
||||||
|
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
|
||||||
binary_output: bool = False,
|
binary_output: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -771,6 +772,7 @@ class Pipeline(_ScikitCompat):
|
|||||||
self.device = torch.device(f"cuda:{device}")
|
self.device = torch.device(f"cuda:{device}")
|
||||||
else:
|
else:
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
self.binary_output = binary_output
|
self.binary_output = binary_output
|
||||||
|
|
||||||
# Special handling
|
# Special handling
|
||||||
|
|||||||
@@ -300,3 +300,11 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_small_model_fp16(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
|
||||||
|
pipe("This is a test")
|
||||||
|
|||||||
Reference in New Issue
Block a user