Fix FP16 inference in TextGenerationPipeline (#20913)

* Add torch_dtype attribute to Pipeline

* Use torch_dtype to cast input tensor type in AutomaticSpeechRecognitionPipeline

* Fix code quality

* Add TextGenerationPipeline fp16 test

* Fix code quality

* Remove useless require in tests

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

bofeng huang
2022-12-29 08:19:25 +01:00
committed by GitHub
parent 11c49ed23b
commit fe65657de1
3 changed files with 16 additions and 6 deletions

@@ -300,3 +300,11 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
                }
            ],
        )

    @require_torch
    @require_torch_gpu
    def test_small_model_fp16(self):
        import torch

        pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
        pipe("This is a test")
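
The commit message describes storing a torch_dtype on the pipeline and using it to cast input tensors. A minimal sketch of that idea follows. It is not the code from this commit: SketchPipeline, cast_inputs, and FakeTensor are hypothetical names invented for illustration, dtypes are plain strings so the sketch runs without PyTorch, and skipping non-float inputs is an assumption of this sketch rather than something the commit states.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor so the sketch runs without PyTorch."""

    values: list
    dtype: str = "float32"

    def to(self, dtype: str) -> "FakeTensor":
        # Mirrors the shape of torch.Tensor.to(dtype=...): returns a new object
        # carrying the requested dtype.
        return FakeTensor(self.values, dtype)


class SketchPipeline:
    """Hypothetical, heavily simplified pipeline; not the transformers class."""

    def __init__(self, torch_dtype: Optional[str] = None):
        # Remember the dtype the model was loaded with (None means "leave inputs alone").
        self.torch_dtype = torch_dtype

    def cast_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if self.torch_dtype is None:
            return inputs
        # Assumption of this sketch: only floating-point features are cast,
        # so integer token ids keep their dtype.
        return {
            name: value.to(dtype=self.torch_dtype) if value.dtype.startswith("float") else value
            for name, value in inputs.items()
        }


pipe = SketchPipeline(torch_dtype="float16")
batch = {
    "input_features": FakeTensor([0.1, 0.2]),
    "input_ids": FakeTensor([1, 2], dtype="int64"),
}
casted = pipe.cast_inputs(batch)
print(casted["input_features"].dtype)  # float16
print(casted["input_ids"].dtype)  # int64
```

With real tensors, the same pattern would pass a torch.dtype such as torch.float16 so that float inputs match a half-precision model's weights.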