[Large PR] Entire rework of pipelines. (#13308)
* Enabling dataset iteration on pipelines. Enabling dataset iteration on pipelines. Unifying parameters under `set_parameters` function. Small fix. Last fixes after rebase Remove print. Fixing text2text `generate_kwargs` No more `self.max_length`. Fixing tf only conversational. Consistency in start/stop index over TF/PT. Speeding up drastically on TF (nasty bug where max_length would increase a ton.) Adding test for support for non fast tokenizers. Fixign GPU usage on zero-shot. Fix working on Tf. Update src/transformers/pipelines/base.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Update src/transformers/pipelines/base.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Small cleanup. Remove all asserts + simple format. * Fixing audio-classification for large PR. * Overly explicity null checking. * Encapsulating GPU/CPU pytorch manipulation directly within `base.py`. * Removed internal state for parameters of the pipeline. Instead of overriding implicitly internal state, we moved to real named arguments on every `preprocess`, `_forward`, `postprocess` function. Instead `_sanitize_parameters` will be used to split all kwargs of both __init__ and __call__ into the 3 kinds of named parameters. * Move import warnings. * Small fixes. * Quality. * Another small fix, using the CI to debug faster. * Last fixes. * Last fix. * Small cleanup of tensor moving. * is not None. * Adding a bunch of docs + a iteration test. * Fixing doc style. * KeyDataset = None guard. * RRemoving the Cuda test for pipelines (was testing). * Even more simple iteration test. * Correct import . * Long day. * Fixes in docs. * [WIP] migrating object detection. * Fixed the target_size bug. * Fixup. * Bad variable name. * Fixing `ensure_on_device` respects original ModelOutput.
This commit is contained in:
@@ -96,7 +96,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
|
||||
def run_aggregation_strategy(self, model, tokenizer):
|
||||
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")
|
||||
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
|
||||
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.SIMPLE)
|
||||
outputs = token_classifier("A simple string")
|
||||
self.assertIsInstance(outputs, list)
|
||||
n = len(outputs)
|
||||
@@ -115,7 +115,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
)
|
||||
|
||||
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="first")
|
||||
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
|
||||
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
|
||||
outputs = token_classifier("A simple string")
|
||||
self.assertIsInstance(outputs, list)
|
||||
n = len(outputs)
|
||||
@@ -134,7 +134,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
)
|
||||
|
||||
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="max")
|
||||
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.MAX)
|
||||
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.MAX)
|
||||
outputs = token_classifier("A simple string")
|
||||
self.assertIsInstance(outputs, list)
|
||||
n = len(outputs)
|
||||
@@ -155,7 +155,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
token_classifier = TokenClassificationPipeline(
|
||||
model=model, tokenizer=tokenizer, aggregation_strategy="average"
|
||||
)
|
||||
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.AVERAGE)
|
||||
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.AVERAGE)
|
||||
outputs = token_classifier("A simple string")
|
||||
self.assertIsInstance(outputs, list)
|
||||
n = len(outputs)
|
||||
@@ -175,12 +175,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
token_classifier = pipeline(task="ner", model=model, tokenizer=tokenizer, grouped_entities=True)
|
||||
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
|
||||
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.SIMPLE)
|
||||
with self.assertWarns(UserWarning):
|
||||
token_classifier = pipeline(
|
||||
task="ner", model=model, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
|
||||
)
|
||||
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
|
||||
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@@ -533,7 +533,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]])
|
||||
|
||||
pre_entities = token_classifier.gather_pre_entities(
|
||||
sentence, input_ids, scores, offset_mapping, special_tokens_mask
|
||||
sentence,
|
||||
input_ids,
|
||||
scores,
|
||||
offset_mapping,
|
||||
special_tokens_mask,
|
||||
aggregation_strategy=AggregationStrategy.NONE,
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(pre_entities),
|
||||
@@ -570,6 +575,20 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_no_offset_tokenizer(self):
|
||||
model_name = "Narsil/small2"
|
||||
tokenizer = AutoTokenizer.from_pretrained("Narsil/small2", use_fast=False)
|
||||
token_classifier = pipeline(task="token-classification", model=model_name, tokenizer=tokenizer, framework="pt")
|
||||
outputs = token_classifier("This is a test !")
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[
|
||||
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": None, "end": None},
|
||||
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": None, "end": None},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
model_name = "Narsil/small2"
|
||||
|
||||
Reference in New Issue
Block a user