Fix FP16 inference in TextGenerationPipeline (#20913)
* add torch_dtype attribute to Pipeline
* Use torch_dtype to cast input tensor type in AutomaticSpeechRecognitionPipeline
* Fix code quality
* Add TextGenerationPipeline fp16 test
* Fix code quality
* Remove useless require in tests

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
@@ -300,3 +300,11 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
                 }
             ],
         )
+
+    @require_torch
+    @require_torch_gpu
+    def test_small_model_fp16(self):
+        import torch
+
+        pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
+        pipe("This is a test")
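The commit message describes a `torch_dtype` attribute on the pipeline that is used to cast input tensors to the model's precision. The sketch below illustrates that idea in isolation. It is not the actual transformers implementation: `TinyPipeline` and `cast_inputs` are hypothetical names, and deriving the dtype from the model's first parameter is an assumption made for the example.

```python
import torch
from torch import nn


class TinyPipeline:
    """Hypothetical stand-in for illustration; not the transformers Pipeline class."""

    def __init__(self, model: nn.Module):
        self.model = model

    @property
    def torch_dtype(self):
        # Assumption: report the dtype of the model's first parameter,
        # or None when the model has no parameters.
        first = next(self.model.parameters(), None)
        return None if first is None else first.dtype

    def cast_inputs(self, tensor: torch.Tensor) -> torch.Tensor:
        # Only floating-point inputs (for example audio features) are cast;
        # integer tensors such as token ids are returned unchanged.
        if self.torch_dtype is not None and tensor.is_floating_point():
            return tensor.to(self.torch_dtype)
        return tensor


# Converting weights to float16 works on CPU; no forward pass is run here.
pipe = TinyPipeline(nn.Linear(4, 2).to(torch.float16))
features = pipe.cast_inputs(torch.randn(1, 4))
token_ids = pipe.cast_inputs(torch.tensor([[1, 2, 3]]))
print(features.dtype, token_ids.dtype)
```

Casting only floating-point tensors keeps integer token ids intact, which is likely why a text-generation pipeline mainly needs the dtype passed through to the model, while a speech pipeline with float features also needs its inputs cast.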