[pipeline] A simple fix for half-precision & 8bit models (#21479)
* v1 fix * adapt from suggestions * make style * fix tests * add gpu tests * update docs * fix other tests * Apply suggestions from code review Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * better fix * make fixup * better example * revert changes * proposal * more elegant solution * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -312,3 +312,12 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
||||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
|
||||
pipe("This is a test")
|
||||
|
||||
@require_torch
|
||||
@require_accelerate
|
||||
@require_torch_gpu
|
||||
def test_pipeline_accelerate_top_p(self):
|
||||
import torch
|
||||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
|
||||
pipe("This is a test", do_sample=True, top_p=0.5)
|
||||
|
||||
Reference in New Issue
Block a user