Fix and simplify semantic-segmentation example (#30145)
* Remove unused augmentation * Fix pad_if_smaller() and remove unused augmentation * Add indentation * Fix requirements * Update dataset use instructions * Replace transforms with albumentations * Replace identity transform with None * Fixing formatting * Fixed comment place
This commit is contained in:
committed by
GitHub
parent
41579763ee
commit
56d001b26f
@@ -16,21 +16,20 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import albumentations as A
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@@ -57,118 +56,19 @@ check_min_version("4.40.0.dev0")
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
||||
def pad_if_smaller(img, size, fill=0):
|
||||
size = (size, size) if isinstance(size, int) else size
|
||||
original_width, original_height = img.size
|
||||
pad_height = size[1] - original_height if original_height < size[1] else 0
|
||||
pad_width = size[0] - original_width if original_width < size[0] else 0
|
||||
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
|
||||
return img
|
||||
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
|
||||
"""Set `0` label as with value 255 and then reduce all other labels by 1.
|
||||
|
||||
Example:
|
||||
Initial class labels: 0 - background; 1 - road; 2 - car;
|
||||
Transformed class labels: 255 - background; 0 - road; 1 - car;
|
||||
|
||||
class Compose:
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, image, target):
|
||||
for t in self.transforms:
|
||||
image, target = t(image, target)
|
||||
return image, target
|
||||
|
||||
|
||||
class Identity:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, image, target):
|
||||
return image, target
|
||||
|
||||
|
||||
class Resize:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.resize(image, self.size)
|
||||
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomResize:
|
||||
def __init__(self, min_size, max_size=None):
|
||||
self.min_size = min_size
|
||||
if max_size is None:
|
||||
max_size = min_size
|
||||
self.max_size = max_size
|
||||
|
||||
def __call__(self, image, target):
|
||||
size = random.randint(self.min_size, self.max_size)
|
||||
image = functional.resize(image, size)
|
||||
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomCrop:
|
||||
def __init__(self, size):
|
||||
self.size = size if isinstance(size, tuple) else (size, size)
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = pad_if_smaller(image, self.size)
|
||||
target = pad_if_smaller(target, self.size, fill=255)
|
||||
crop_params = transforms.RandomCrop.get_params(image, self.size)
|
||||
image = functional.crop(image, *crop_params)
|
||||
target = functional.crop(target, *crop_params)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomHorizontalFlip:
|
||||
def __init__(self, flip_prob):
|
||||
self.flip_prob = flip_prob
|
||||
|
||||
def __call__(self, image, target):
|
||||
if random.random() < self.flip_prob:
|
||||
image = functional.hflip(image)
|
||||
target = functional.hflip(target)
|
||||
return image, target
|
||||
|
||||
|
||||
class PILToTensor:
|
||||
def __call__(self, image, target):
|
||||
image = functional.pil_to_tensor(image)
|
||||
target = torch.as_tensor(np.array(target), dtype=torch.int64)
|
||||
return image, target
|
||||
|
||||
|
||||
class ConvertImageDtype:
|
||||
def __init__(self, dtype):
|
||||
self.dtype = dtype
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.convert_image_dtype(image, self.dtype)
|
||||
return image, target
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.normalize(image, mean=self.mean, std=self.std)
|
||||
return image, target
|
||||
|
||||
|
||||
class ReduceLabels:
|
||||
def __call__(self, image, target):
|
||||
if not isinstance(target, np.ndarray):
|
||||
target = np.array(target).astype(np.uint8)
|
||||
# avoid using underflow conversion
|
||||
target[target == 0] = 255
|
||||
target = target - 1
|
||||
target[target == 254] = 255
|
||||
|
||||
target = Image.fromarray(target)
|
||||
return image, target
|
||||
**kwargs are required to use this function with albumentations.
|
||||
"""
|
||||
labels[labels == 0] = 255
|
||||
labels = labels - 1
|
||||
labels[labels == 254] = 255
|
||||
return labels
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -365,7 +265,7 @@ def main():
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: str(k) for k, v in id2label.items()}
|
||||
|
||||
# Load the mean IoU metric from the datasets package
|
||||
# Load the mean IoU metric from the evaluate package
|
||||
metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)
|
||||
|
||||
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
||||
@@ -424,64 +324,62 @@ def main():
|
||||
token=model_args.token,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
# `reduce_labels` is a property of dataset labels, in case we use image_processor
|
||||
# pretrained on another dataset we should override the default setting
|
||||
image_processor.do_reduce_labels = data_args.reduce_labels
|
||||
|
||||
# Define torchvision transforms to be applied to each image + target.
|
||||
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
|
||||
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
|
||||
# Define transforms to be applied to each image and target.
|
||||
if "shortest_edge" in image_processor.size:
|
||||
# We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
|
||||
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
|
||||
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
train_transforms = Compose(
|
||||
height, width = image_processor.size["height"], image_processor.size["width"]
|
||||
train_transforms = A.Compose(
|
||||
[
|
||||
ReduceLabels() if data_args.reduce_labels else Identity(),
|
||||
RandomCrop(size=size),
|
||||
RandomHorizontalFlip(flip_prob=0.5),
|
||||
PILToTensor(),
|
||||
ConvertImageDtype(torch.float),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
A.Lambda(
|
||||
name="reduce_labels",
|
||||
mask=reduce_labels_transform if data_args.reduce_labels else None,
|
||||
p=1.0,
|
||||
),
|
||||
# pad image with 255, because it is ignored by loss
|
||||
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
|
||||
A.RandomCrop(height=height, width=width, p=1.0),
|
||||
A.HorizontalFlip(p=0.5),
|
||||
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||
ToTensorV2(),
|
||||
]
|
||||
)
|
||||
# Define torchvision transform to be applied to each image.
|
||||
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
|
||||
val_transforms = Compose(
|
||||
val_transforms = A.Compose(
|
||||
[
|
||||
ReduceLabels() if data_args.reduce_labels else Identity(),
|
||||
Resize(size=size),
|
||||
PILToTensor(),
|
||||
ConvertImageDtype(torch.float),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
A.Lambda(
|
||||
name="reduce_labels",
|
||||
mask=reduce_labels_transform if data_args.reduce_labels else None,
|
||||
p=1.0,
|
||||
),
|
||||
A.Resize(height=height, width=width, p=1.0),
|
||||
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||
ToTensorV2(),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(example_batch):
|
||||
def preprocess_batch(example_batch, transforms: A.Compose):
|
||||
pixel_values = []
|
||||
labels = []
|
||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
||||
image, target = train_transforms(image.convert("RGB"), target)
|
||||
pixel_values.append(image)
|
||||
labels.append(target)
|
||||
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
|
||||
pixel_values.append(transformed["image"])
|
||||
labels.append(transformed["mask"])
|
||||
|
||||
encoding = {}
|
||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
||||
encoding["labels"] = torch.stack(labels)
|
||||
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
|
||||
encoding["labels"] = torch.stack(labels).to(torch.long)
|
||||
|
||||
return encoding
|
||||
|
||||
def preprocess_val(example_batch):
|
||||
pixel_values = []
|
||||
labels = []
|
||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
||||
image, target = val_transforms(image.convert("RGB"), target)
|
||||
pixel_values.append(image)
|
||||
labels.append(target)
|
||||
|
||||
encoding = {}
|
||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
||||
encoding["labels"] = torch.stack(labels)
|
||||
|
||||
return encoding
|
||||
# Preprocess function for dataset should have only one argument,
|
||||
# so we use partial to pass the transforms
|
||||
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
|
||||
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in dataset:
|
||||
@@ -491,7 +389,7 @@ def main():
|
||||
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
||||
)
|
||||
# Set the training transforms
|
||||
dataset["train"].set_transform(preprocess_train)
|
||||
dataset["train"].set_transform(preprocess_train_batch_fn)
|
||||
|
||||
if training_args.do_eval:
|
||||
if "validation" not in dataset:
|
||||
@@ -501,7 +399,7 @@ def main():
|
||||
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
||||
)
|
||||
# Set the validation transforms
|
||||
dataset["validation"].set_transform(preprocess_val)
|
||||
dataset["validation"].set_transform(preprocess_val_batch_fn)
|
||||
|
||||
# Initialize our trainer
|
||||
trainer = Trainer(
|
||||
|
||||
Reference in New Issue
Block a user