[Large PR] Entire rework of pipelines. (#13308)
* Enabling dataset iteration on pipelines. Enabling dataset iteration on pipelines. Unifying parameters under `set_parameters` function. Small fix. Last fixes after rebase Remove print. Fixing text2text `generate_kwargs` No more `self.max_length`. Fixing tf only conversational. Consistency in start/stop index over TF/PT. Speeding up drastically on TF (nasty bug where max_length would increase a ton.) Adding test for support for non fast tokenizers. Fixign GPU usage on zero-shot. Fix working on Tf. Update src/transformers/pipelines/base.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Update src/transformers/pipelines/base.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Small cleanup. Remove all asserts + simple format. * Fixing audio-classification for large PR. * Overly explicity null checking. * Encapsulating GPU/CPU pytorch manipulation directly within `base.py`. * Removed internal state for parameters of the pipeline. Instead of overriding implicitly internal state, we moved to real named arguments on every `preprocess`, `_forward`, `postprocess` function. Instead `_sanitize_parameters` will be used to split all kwargs of both __init__ and __call__ into the 3 kinds of named parameters. * Move import warnings. * Small fixes. * Quality. * Another small fix, using the CI to debug faster. * Last fixes. * Last fix. * Small cleanup of tensor moving. * is not None. * Adding a bunch of docs + a iteration test. * Fixing doc style. * KeyDataset = None guard. * RRemoving the Cuda test for pipelines (was testing). * Even more simple iteration test. * Correct import . * Long day. * Fixes in docs. * [WIP] migrating object detection. * Fixed the target_size bug. * Fixup. * Bad variable name. * Fixing `ensure_on_device` respects original ModelOutput.
This commit is contained in:
@@ -15,11 +15,13 @@
|
||||
import importlib
|
||||
import logging
|
||||
import string
|
||||
import unittest
|
||||
from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from unittest import skipIf
|
||||
|
||||
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer
|
||||
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer, pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, require_torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -177,3 +179,30 @@ class PipelineTestCaseMeta(type):
|
||||
dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner)
|
||||
|
||||
return type.__new__(mcs, name, bases, dct)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class CommonPipelineTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_pipeline_iteration(self):
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class MyDataset(Dataset):
|
||||
data = [
|
||||
"This is a test",
|
||||
"This restaurant is great",
|
||||
"This restaurant is awful",
|
||||
]
|
||||
|
||||
def __len__(self):
|
||||
return 3
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.data[i]
|
||||
|
||||
text_classifier = pipeline(
|
||||
task="text-classification", model="Narsil/tiny-distilbert-sequence-classification", framework="pt"
|
||||
)
|
||||
dataset = MyDataset()
|
||||
for output in text_classifier(dataset):
|
||||
self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
|
||||
|
||||
Reference in New Issue
Block a user