Fix pytorch image classification example (#14883)

* Update example

* Remove skip in tests
This commit is contained in:
Mario Šaško
2021-12-22 14:42:19 +01:00
committed by GitHub
parent 7df4b90c76
commit 1045a36c1f
2 changed files with 4 additions and 4 deletions

View File

@@ -279,12 +279,14 @@ def main():
def train_transforms(example_batch): def train_transforms(example_batch):
"""Apply _train_transforms across a batch.""" """Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]] example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
]
return example_batch return example_batch
def val_transforms(example_batch): def val_transforms(example_batch):
"""Apply _val_transforms across a batch.""" """Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]] example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
return example_batch return example_batch
if training_args.do_train: if training_args.do_train:

View File

@@ -19,7 +19,6 @@ import json
import logging import logging
import os import os
import sys import sys
import unittest
from unittest.mock import patch from unittest.mock import patch
import torch import torch
@@ -409,7 +408,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir) result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30) self.assertGreaterEqual(result["eval_bleu"], 30)
@unittest.skip("Fix me Nate!")
def test_run_image_classification(self): def test_run_image_classification(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)