[Large PR] Entire rework of pipelines. (#13308)

* Enabling dataset iteration on pipelines.

Enabling dataset iteration on pipelines.

Unifying parameters under `set_parameters` function.

Small fix.

Last fixes after rebase

Remove print.

Fixing text2text `generate_kwargs`

No more `self.max_length`.

Fixing tf only conversational.

Consistency in start/stop index over TF/PT.

Speeding up drastically on TF (nasty bug where max_length would increase
a ton.)

Adding test for support for non fast tokenizers.

Fixign GPU usage on zero-shot.

Fix working on Tf.

Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Small cleanup.

Remove all asserts + simple format.

* Fixing audio-classification for large PR.

* Overly explicity null checking.

* Encapsulating GPU/CPU pytorch manipulation directly within `base.py`.

* Removed internal state for parameters of the  pipeline.

Instead of overriding implicitly internal state, we moved
to real named arguments on every `preprocess`, `_forward`,
`postprocess` function.

Instead `_sanitize_parameters` will be used to split all kwargs
of both __init__ and __call__ into the 3 kinds of named parameters.

* Move import warnings.

* Small fixes.

* Quality.

* Another small fix, using the CI to debug faster.

* Last fixes.

* Last fix.

* Small cleanup of tensor moving.

* is not None.

* Adding a bunch of docs + a iteration test.

* Fixing doc style.

* KeyDataset = None guard.

* RRemoving the Cuda test for pipelines (was testing).

* Even more simple iteration test.

* Correct import .

* Long day.

* Fixes in docs.

* [WIP] migrating object detection.

* Fixed the target_size bug.

* Fixup.

* Bad variable name.

* Fixing `ensure_on_device` respects original ModelOutput.
This commit is contained in:
Nicolas Patry
2021-09-10 14:47:48 +02:00
committed by GitHub
parent 09549aa18c
commit c63fcabfe9
28 changed files with 1559 additions and 1290 deletions

View File

@@ -96,7 +96,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
def run_aggregation_strategy(self, model, tokenizer):
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.SIMPLE)
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)
n = len(outputs)
@@ -115,7 +115,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
)
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="first")
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)
n = len(outputs)
@@ -134,7 +134,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
)
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="max")
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.MAX)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.MAX)
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)
n = len(outputs)
@@ -155,7 +155,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="average"
)
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.AVERAGE)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.AVERAGE)
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)
n = len(outputs)
@@ -175,12 +175,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
with self.assertWarns(UserWarning):
token_classifier = pipeline(task="ner", model=model, tokenizer=tokenizer, grouped_entities=True)
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.SIMPLE)
with self.assertWarns(UserWarning):
token_classifier = pipeline(
task="ner", model=model, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
)
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
@require_torch
@slow
@@ -533,7 +533,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]])
pre_entities = token_classifier.gather_pre_entities(
sentence, input_ids, scores, offset_mapping, special_tokens_mask
sentence,
input_ids,
scores,
offset_mapping,
special_tokens_mask,
aggregation_strategy=AggregationStrategy.NONE,
)
self.assertEqual(
nested_simplify(pre_entities),
@@ -570,6 +575,20 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
],
)
@require_torch
def test_no_offset_tokenizer(self):
model_name = "Narsil/small2"
tokenizer = AutoTokenizer.from_pretrained("Narsil/small2", use_fast=False)
token_classifier = pipeline(task="token-classification", model=model_name, tokenizer=tokenizer, framework="pt")
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": None, "end": None},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": None, "end": None},
],
)
@require_torch
def test_small_model_pt(self):
model_name = "Narsil/small2"