Adding support for hidden_states and attentions in unbatching (#14420)
support.
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers import (
|
||||
TOKENIZER_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
DistilBertForSequenceClassification,
|
||||
IBertConfig,
|
||||
RobertaConfig,
|
||||
TextClassificationPipeline,
|
||||
@@ -322,6 +323,19 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
results.append(out)
|
||||
self.assertEqual(len(results), 10)
|
||||
|
||||
@require_torch
|
||||
def test_unbatch_attentions_hidden_states(self):
|
||||
model = DistilBertForSequenceClassification.from_pretrained(
|
||||
"Narsil/tiny-distilbert-sequence-classification", output_hidden_states=True, output_attentions=True
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Narsil/tiny-distilbert-sequence-classification")
|
||||
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
# Used to throw an error because `hidden_states` are a tuple of tensors
|
||||
# instead of the expected tensor.
|
||||
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
|
||||
self.assertEqual(len(outputs), 20)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelinePadTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user