From b56848c8c8d29822cb970abd56bc620cece937a3 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 17 Jun 2021 09:41:16 +0200 Subject: [PATCH] Pipeline update & tests (#12207) --- .../pipelines/image_classification.py | 6 +++- tests/test_pipelines_image_classification.py | 31 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index eb0410f322..76a519a988 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -87,7 +87,8 @@ class ImageClassificationPipeline(Pipeline): Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL images. top_k (:obj:`int`, `optional`, defaults to 5): - The number of top labels that will be returned by the pipeline. + The number of top labels that will be returned by the pipeline. If the provided number is higher than + the number of labels available in the model configuration, it will default to the number of labels. Return: A dictionary or a list of dictionaries containing result. If the input is a single image, will return a @@ -106,6 +107,9 @@ class ImageClassificationPipeline(Pipeline): images = [self.load_image(image) for image in images] + if top_k > self.model.config.num_labels: + top_k = self.model.config.num_labels + with torch.no_grad(): inputs = self.feature_extractor(images=images, return_tensors="pt") outputs = self.model(**inputs) diff --git a/tests/test_pipelines_image_classification.py b/tests/test_pipelines_image_classification.py index ecfab4c76d..0306523255 100644 --- a/tests/test_pipelines_image_classification.py +++ b/tests/test_pipelines_image_classification.py @@ -15,6 +15,7 @@ import unittest from transformers import ( + AutoConfig, AutoFeatureExtractor, AutoModelForImageClassification, PreTrainedTokenizer, @@ -128,3 +129,33 @@ class ImageClassificationPipelineTests(unittest.TestCase): image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer) self.assertIs(image_classifier.tokenizer, tokenizer) + + def test_num_labels_inferior_to_topk(self): + for small_model in self.small_models: + + num_labels = 2 + model = AutoModelForImageClassification.from_config( + AutoConfig.from_pretrained(small_model, num_labels=num_labels) + ) + feature_extractor = AutoFeatureExtractor.from_pretrained(small_model) + image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor) + + for valid_input in self.valid_inputs: + output = image_classifier(**valid_input) + + def assert_valid_pipeline_output(pipeline_output): + self.assertTrue(isinstance(pipeline_output, list)) + self.assertEqual(len(pipeline_output), num_labels) + for label_result in pipeline_output: + self.assertTrue(isinstance(label_result, dict)) + self.assertIn("label", label_result) + self.assertIn("score", label_result) + + if isinstance(valid_input["images"], list): + # When images are batched, pipeline output is a list of lists of dictionaries + self.assertEqual(len(valid_input["images"]), len(output)) + for individual_output in output: + assert_valid_pipeline_output(individual_output) + else: + # When images are batched, pipeline output is a list of dictionaries + assert_valid_pipeline_output(output)