[pipeline] A simple fix for half-precision & 8bit models (#21479)

* v1 fix

* adapt from suggestions

* make style

* fix tests

* add gpu tests

* update docs

* fix other tests

* Apply suggestions from code review

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* better fix

* make fixup

* better example

* revert changes

* proposal

* more elegant solution

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-02-10 10:26:17 +01:00
committed by GitHub
parent 97d3390fc8
commit f83942684d
6 changed files with 63 additions and 12 deletions

View File

@@ -312,3 +312,12 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
pipe("This is a test")
@require_torch
@require_accelerate
@require_torch_gpu
def test_pipeline_accelerate_top_p(self):
import torch
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
pipe("This is a test", do_sample=True, top_p=0.5)