Adding support for hidden_states and attentions in unbatching (#14420)
support.
This commit is contained in:
@@ -747,9 +747,14 @@ if is_torch_available():
|
|||||||
else:
|
else:
|
||||||
loader_batched = {}
|
loader_batched = {}
|
||||||
for k, element in self._loader_batch_data.items():
|
for k, element in self._loader_batch_data.items():
|
||||||
if k == "past_key_values":
|
if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
|
||||||
continue
|
if isinstance(element[0], torch.Tensor):
|
||||||
if isinstance(element[self._loader_batch_index], torch.Tensor):
|
loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
|
||||||
|
elif isinstance(element[0], np.ndarray):
|
||||||
|
loader_batched[k] = tuple(
|
||||||
|
np.expand_dims(el[self._loader_batch_index], 0) for el in element
|
||||||
|
)
|
||||||
|
elif isinstance(element[self._loader_batch_index], torch.Tensor):
|
||||||
loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
|
loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
|
||||||
elif isinstance(element[self._loader_batch_index], np.ndarray):
|
elif isinstance(element[self._loader_batch_index], np.ndarray):
|
||||||
loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
|
loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from transformers import (
|
|||||||
TOKENIZER_MAPPING,
|
TOKENIZER_MAPPING,
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
DistilBertForSequenceClassification,
|
||||||
IBertConfig,
|
IBertConfig,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
@@ -322,6 +323,19 @@ class CommonPipelineTest(unittest.TestCase):
|
|||||||
results.append(out)
|
results.append(out)
|
||||||
self.assertEqual(len(results), 10)
|
self.assertEqual(len(results), 10)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_unbatch_attentions_hidden_states(self):
|
||||||
|
model = DistilBertForSequenceClassification.from_pretrained(
|
||||||
|
"Narsil/tiny-distilbert-sequence-classification", output_hidden_states=True, output_attentions=True
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("Narsil/tiny-distilbert-sequence-classification")
|
||||||
|
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
# Used to throw an error because `hidden_states` are a tuple of tensors
|
||||||
|
# instead of the expected tensor.
|
||||||
|
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
|
||||||
|
self.assertEqual(len(outputs), 20)
|
||||||
|
|
||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
class PipelinePadTest(unittest.TestCase):
|
class PipelinePadTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user