Adding support for hidden_states and attentions in unbatching (#14420)
support.
This commit is contained in:
@@ -747,9 +747,14 @@ if is_torch_available():
|
||||
else:
|
||||
loader_batched = {}
|
||||
for k, element in self._loader_batch_data.items():
|
||||
if k == "past_key_values":
|
||||
continue
|
||||
if isinstance(element[self._loader_batch_index], torch.Tensor):
|
||||
if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
|
||||
if isinstance(element[0], torch.Tensor):
|
||||
loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
|
||||
elif isinstance(element[0], np.ndarray):
|
||||
loader_batched[k] = tuple(
|
||||
np.expand_dims(el[self._loader_batch_index], 0) for el in element
|
||||
)
|
||||
elif isinstance(element[self._loader_batch_index], torch.Tensor):
|
||||
loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
|
||||
elif isinstance(element[self._loader_batch_index], np.ndarray):
|
||||
loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
|
||||
|
||||
@@ -27,6 +27,7 @@ from transformers import (
|
||||
TOKENIZER_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
DistilBertForSequenceClassification,
|
||||
IBertConfig,
|
||||
RobertaConfig,
|
||||
TextClassificationPipeline,
|
||||
@@ -322,6 +323,19 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
results.append(out)
|
||||
self.assertEqual(len(results), 10)
|
||||
|
||||
@require_torch
|
||||
def test_unbatch_attentions_hidden_states(self):
|
||||
model = DistilBertForSequenceClassification.from_pretrained(
|
||||
"Narsil/tiny-distilbert-sequence-classification", output_hidden_states=True, output_attentions=True
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Narsil/tiny-distilbert-sequence-classification")
|
||||
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
# Used to throw an error because `hidden_states` are a tuple of tensors
|
||||
# instead of the expected tensor.
|
||||
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
|
||||
self.assertEqual(len(outputs), 20)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelinePadTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user