🚨🚨🚨 [pipelines] update defaults in pipelines that can generate (#38129)
* pipeline generation defaults * add max_new_tokens=20 in test pipelines * pop all kwargs that are used to parameterize generation config * add class attr that tell us whether a pipeline calls generate * tmp commit * pt text gen pipeline tests passing * remove failing tf tests * fix text gen pipeline mixin test corner case * update text_to_audio pipeline tests * trigger tests * a few more tests * skips * some more audio tests * not slow * broken * lower severity of generation mode errors * fix all asr pipeline tests * nit * skip * image to text pipeline tests * text2test pipeline * last pipelines * fix flaky * PR comments * handle generate attrs more carefully in models that cant generate * same as above
This commit is contained in:
@@ -21,7 +21,7 @@ from transformers import (
|
||||
TFPreTrainedModel,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import is_pipeline_test, require_torch, slow, torch_device
|
||||
from transformers.tokenization_utils import TruncationStrategy
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
@@ -48,6 +48,7 @@ class SummarizationPipelineTests(unittest.TestCase):
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
max_new_tokens=20,
|
||||
)
|
||||
return summarizer, ["(CNN)The Palestinian Authority officially became", "Some other text"]
|
||||
|
||||
@@ -92,20 +93,7 @@ class SummarizationPipelineTests(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
summarizer = pipeline(task="summarization", model="sshleifer/tiny-mbart", framework="pt")
|
||||
outputs = summarizer("This is a small test")
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"summary_text": "เข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไปเข้าไป"
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
summarizer = pipeline(task="summarization", model="sshleifer/tiny-mbart", framework="tf")
|
||||
summarizer = pipeline(task="summarization", model="sshleifer/tiny-mbart", framework="pt", max_new_tokens=19)
|
||||
outputs = summarizer("This is a small test")
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
|
||||
Reference in New Issue
Block a user