Adding support for hidden_states and attentions in unbatching (#14420)

support.
This commit is contained in:
Nicolas Patry
2021-11-19 15:37:52 +01:00
committed by GitHub
parent f25a9332e8
commit 81fe8afaac
2 changed files with 22 additions and 3 deletions

View File

@@ -27,6 +27,7 @@ from transformers import (
TOKENIZER_MAPPING,
AutoFeatureExtractor,
AutoTokenizer,
DistilBertForSequenceClassification,
IBertConfig,
RobertaConfig,
TextClassificationPipeline,
@@ -322,6 +323,19 @@ class CommonPipelineTest(unittest.TestCase):
results.append(out)
self.assertEqual(len(results), 10)
@require_torch
def test_unbatch_attentions_hidden_states(self):
model = DistilBertForSequenceClassification.from_pretrained(
"Narsil/tiny-distilbert-sequence-classification", output_hidden_states=True, output_attentions=True
)
tokenizer = AutoTokenizer.from_pretrained("Narsil/tiny-distilbert-sequence-classification")
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
# Used to throw an error because `hidden_states` are a tuple of tensors
# instead of the expected tensor.
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
self.assertEqual(len(outputs), 20)
@is_pipeline_test
class PipelinePadTest(unittest.TestCase):