From 1045a36c1f670fb6d18d71260bc9905f5b551843 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Wed, 22 Dec 2021 14:42:19 +0100 Subject: [PATCH] Fix pytorch image classification example (#14883) * Update example * Remove skip in tests --- .../image-classification/run_image_classification.py | 6 ++++-- examples/pytorch/test_examples.py | 2 -- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py index f133ddbd2c..62e1cf49d6 100644 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -279,12 +279,14 @@ def main(): def train_transforms(example_batch): """Apply _train_transforms across a batch.""" - example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]] + example_batch["pixel_values"] = [ + _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"] + ] return example_batch def val_transforms(example_batch): """Apply _val_transforms across a batch.""" - example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]] + example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]] return example_batch if training_args.do_train: diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index dee4e05511..1a1c2ea06a 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -19,7 +19,6 @@ import json import logging import os import sys -import unittest from unittest.mock import patch import torch @@ -409,7 +408,6 @@ class ExamplesTests(TestCasePlus): result = get_results(tmp_dir) self.assertGreaterEqual(result["eval_bleu"], 30) - @unittest.skip("Fix me Nate!") def test_run_image_classification(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler)