Fix pytorch image classification example (#14883)
* Update example * Remove skip in tests
This commit is contained in:
@@ -279,12 +279,14 @@ def main():
|
|||||||
|
|
||||||
def train_transforms(example_batch):
|
def train_transforms(example_batch):
|
||||||
"""Apply _train_transforms across a batch."""
|
"""Apply _train_transforms across a batch."""
|
||||||
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
example_batch["pixel_values"] = [
|
||||||
|
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
|
||||||
|
]
|
||||||
return example_batch
|
return example_batch
|
||||||
|
|
||||||
def val_transforms(example_batch):
|
def val_transforms(example_batch):
|
||||||
"""Apply _val_transforms across a batch."""
|
"""Apply _val_transforms across a batch."""
|
||||||
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
|
||||||
return example_batch
|
return example_batch
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -409,7 +408,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||||
|
|
||||||
@unittest.skip("Fix me Nate!")
|
|
||||||
def test_run_image_classification(self):
|
def test_run_image_classification(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|||||||
Reference in New Issue
Block a user