Fix and simplify semantic-segmentation example (#30145)
* Remove unused augmentation * Fix pad_if_smaller() and remove unused augmentation * Add indentation * Fix requirements * Update dataset use instructions * Replace transforms with albumentations * Replace identity transform with None * Fixing formatting * Fixed comment place
This commit is contained in:
committed by
GitHub
parent
41579763ee
commit
56d001b26f
@@ -25,3 +25,4 @@ torchaudio
|
|||||||
jiwer
|
jiwer
|
||||||
librosa
|
librosa
|
||||||
evaluate >= 0.2.0
|
evaluate >= 0.2.0
|
||||||
|
albumentations
|
||||||
|
|||||||
@@ -97,6 +97,10 @@ The script leverages the [🤗 Trainer API](https://huggingface.co/docs/transfor
|
|||||||
|
|
||||||
Here we show how to fine-tune a [SegFormer](https://huggingface.co/nvidia/mit-b0) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset:
|
Here we show how to fine-tune a [SegFormer](https://huggingface.co/nvidia/mit-b0) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset:
|
||||||
|
|
||||||
|
In order to use `segments/sidewalk-semantic`:
|
||||||
|
- Log in to Hugging Face with `huggingface-cli login` (token can be accessed [here](https://huggingface.co/settings/tokens)).
|
||||||
|
- Accept the terms of use for `sidewalk-semantic` on the [dataset page](https://huggingface.co/datasets/segments/sidewalk-semantic).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python run_semantic_segmentation.py \
|
python run_semantic_segmentation.py \
|
||||||
--model_name_or_path nvidia/mit-b0 \
|
--model_name_or_path nvidia/mit-b0 \
|
||||||
@@ -105,7 +109,6 @@ python run_semantic_segmentation.py \
|
|||||||
--remove_unused_columns False \
|
--remove_unused_columns False \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_eval \
|
--do_eval \
|
||||||
--evaluation_strategy steps \
|
|
||||||
--push_to_hub \
|
--push_to_hub \
|
||||||
--push_to_hub_model_id segformer-finetuned-sidewalk-10k-steps \
|
--push_to_hub_model_id segformer-finetuned-sidewalk-10k-steps \
|
||||||
--max_steps 10000 \
|
--max_steps 10000 \
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
git://github.com/huggingface/accelerate.git
|
|
||||||
datasets >= 2.0.0
|
datasets >= 2.0.0
|
||||||
torch >= 1.3
|
torch >= 1.3
|
||||||
|
accelerate
|
||||||
evaluate
|
evaluate
|
||||||
|
Pillow
|
||||||
|
albumentations
|
||||||
@@ -16,21 +16,20 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import albumentations as A
|
||||||
import evaluate
|
import evaluate
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from albumentations.pytorch import ToTensorV2
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from PIL import Image
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchvision import transforms
|
|
||||||
from torchvision.transforms import functional
|
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -57,118 +56,19 @@ check_min_version("4.40.0.dev0")
|
|||||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||||
|
|
||||||
|
|
||||||
def pad_if_smaller(img, size, fill=0):
|
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
|
||||||
size = (size, size) if isinstance(size, int) else size
|
"""Set the `0` label to 255 and then reduce all other labels by 1.
|
||||||
original_width, original_height = img.size
|
|
||||||
pad_height = size[1] - original_height if original_height < size[1] else 0
|
|
||||||
pad_width = size[0] - original_width if original_width < size[0] else 0
|
|
||||||
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Initial class labels: 0 - background; 1 - road; 2 - car;
|
||||||
|
Transformed class labels: 255 - background; 0 - road; 1 - car;
|
||||||
|
|
||||||
class Compose:
|
**kwargs are required to use this function with albumentations.
|
||||||
def __init__(self, transforms):
|
"""
|
||||||
self.transforms = transforms
|
labels[labels == 0] = 255
|
||||||
|
labels = labels - 1
|
||||||
def __call__(self, image, target):
|
labels[labels == 254] = 255
|
||||||
for t in self.transforms:
|
return labels
|
||||||
image, target = t(image, target)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class Identity:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class Resize:
|
|
||||||
def __init__(self, size):
|
|
||||||
self.size = size
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = functional.resize(image, self.size)
|
|
||||||
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class RandomResize:
|
|
||||||
def __init__(self, min_size, max_size=None):
|
|
||||||
self.min_size = min_size
|
|
||||||
if max_size is None:
|
|
||||||
max_size = min_size
|
|
||||||
self.max_size = max_size
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
size = random.randint(self.min_size, self.max_size)
|
|
||||||
image = functional.resize(image, size)
|
|
||||||
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class RandomCrop:
|
|
||||||
def __init__(self, size):
|
|
||||||
self.size = size if isinstance(size, tuple) else (size, size)
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = pad_if_smaller(image, self.size)
|
|
||||||
target = pad_if_smaller(target, self.size, fill=255)
|
|
||||||
crop_params = transforms.RandomCrop.get_params(image, self.size)
|
|
||||||
image = functional.crop(image, *crop_params)
|
|
||||||
target = functional.crop(target, *crop_params)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class RandomHorizontalFlip:
|
|
||||||
def __init__(self, flip_prob):
|
|
||||||
self.flip_prob = flip_prob
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
if random.random() < self.flip_prob:
|
|
||||||
image = functional.hflip(image)
|
|
||||||
target = functional.hflip(target)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class PILToTensor:
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = functional.pil_to_tensor(image)
|
|
||||||
target = torch.as_tensor(np.array(target), dtype=torch.int64)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertImageDtype:
|
|
||||||
def __init__(self, dtype):
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = functional.convert_image_dtype(image, self.dtype)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class Normalize:
|
|
||||||
def __init__(self, mean, std):
|
|
||||||
self.mean = mean
|
|
||||||
self.std = std
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = functional.normalize(image, mean=self.mean, std=self.std)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class ReduceLabels:
|
|
||||||
def __call__(self, image, target):
|
|
||||||
if not isinstance(target, np.ndarray):
|
|
||||||
target = np.array(target).astype(np.uint8)
|
|
||||||
# avoid using underflow conversion
|
|
||||||
target[target == 0] = 255
|
|
||||||
target = target - 1
|
|
||||||
target[target == 254] = 255
|
|
||||||
|
|
||||||
target = Image.fromarray(target)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -365,7 +265,7 @@ def main():
|
|||||||
id2label = {int(k): v for k, v in id2label.items()}
|
id2label = {int(k): v for k, v in id2label.items()}
|
||||||
label2id = {v: str(k) for k, v in id2label.items()}
|
label2id = {v: str(k) for k, v in id2label.items()}
|
||||||
|
|
||||||
# Load the mean IoU metric from the datasets package
|
# Load the mean IoU metric from the evaluate package
|
||||||
metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)
|
metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)
|
||||||
|
|
||||||
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
||||||
@@ -424,64 +324,62 @@ def main():
|
|||||||
token=model_args.token,
|
token=model_args.token,
|
||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
)
|
)
|
||||||
|
# `reduce_labels` is a property of dataset labels, in case we use image_processor
|
||||||
|
# pretrained on another dataset we should override the default setting
|
||||||
|
image_processor.do_reduce_labels = data_args.reduce_labels
|
||||||
|
|
||||||
# Define torchvision transforms to be applied to each image + target.
|
# Define transforms to be applied to each image and target.
|
||||||
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
|
|
||||||
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
|
|
||||||
if "shortest_edge" in image_processor.size:
|
if "shortest_edge" in image_processor.size:
|
||||||
# We instead set the target size to (shortest_edge, shortest_edge) here to ensure all images are batchable.
|
# We instead set the target size to (shortest_edge, shortest_edge) here to ensure all images are batchable.
|
||||||
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
|
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
|
||||||
else:
|
else:
|
||||||
size = (image_processor.size["height"], image_processor.size["width"])
|
height, width = image_processor.size["height"], image_processor.size["width"]
|
||||||
train_transforms = Compose(
|
train_transforms = A.Compose(
|
||||||
[
|
[
|
||||||
ReduceLabels() if data_args.reduce_labels else Identity(),
|
A.Lambda(
|
||||||
RandomCrop(size=size),
|
name="reduce_labels",
|
||||||
RandomHorizontalFlip(flip_prob=0.5),
|
mask=reduce_labels_transform if data_args.reduce_labels else None,
|
||||||
PILToTensor(),
|
p=1.0,
|
||||||
ConvertImageDtype(torch.float),
|
),
|
||||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
# pad image with 255, because it is ignored by loss
|
||||||
|
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
|
||||||
|
A.RandomCrop(height=height, width=width, p=1.0),
|
||||||
|
A.HorizontalFlip(p=0.5),
|
||||||
|
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||||
|
ToTensorV2(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
# Define torchvision transform to be applied to each image.
|
val_transforms = A.Compose(
|
||||||
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
|
|
||||||
val_transforms = Compose(
|
|
||||||
[
|
[
|
||||||
ReduceLabels() if data_args.reduce_labels else Identity(),
|
A.Lambda(
|
||||||
Resize(size=size),
|
name="reduce_labels",
|
||||||
PILToTensor(),
|
mask=reduce_labels_transform if data_args.reduce_labels else None,
|
||||||
ConvertImageDtype(torch.float),
|
p=1.0,
|
||||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
),
|
||||||
|
A.Resize(height=height, width=width, p=1.0),
|
||||||
|
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||||
|
ToTensorV2(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def preprocess_train(example_batch):
|
def preprocess_batch(example_batch, transforms: A.Compose):
|
||||||
pixel_values = []
|
pixel_values = []
|
||||||
labels = []
|
labels = []
|
||||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
for image, target in zip(example_batch["image"], example_batch["label"]):
|
||||||
image, target = train_transforms(image.convert("RGB"), target)
|
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
|
||||||
pixel_values.append(image)
|
pixel_values.append(transformed["image"])
|
||||||
labels.append(target)
|
labels.append(transformed["mask"])
|
||||||
|
|
||||||
encoding = {}
|
encoding = {}
|
||||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
|
||||||
encoding["labels"] = torch.stack(labels)
|
encoding["labels"] = torch.stack(labels).to(torch.long)
|
||||||
|
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
def preprocess_val(example_batch):
|
# Preprocess function for dataset should have only one argument,
|
||||||
pixel_values = []
|
# so we use partial to pass the transforms
|
||||||
labels = []
|
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
|
||||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
|
||||||
image, target = val_transforms(image.convert("RGB"), target)
|
|
||||||
pixel_values.append(image)
|
|
||||||
labels.append(target)
|
|
||||||
|
|
||||||
encoding = {}
|
|
||||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
|
||||||
encoding["labels"] = torch.stack(labels)
|
|
||||||
|
|
||||||
return encoding
|
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in dataset:
|
if "train" not in dataset:
|
||||||
@@ -491,7 +389,7 @@ def main():
|
|||||||
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
||||||
)
|
)
|
||||||
# Set the training transforms
|
# Set the training transforms
|
||||||
dataset["train"].set_transform(preprocess_train)
|
dataset["train"].set_transform(preprocess_train_batch_fn)
|
||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
if "validation" not in dataset:
|
if "validation" not in dataset:
|
||||||
@@ -501,7 +399,7 @@ def main():
|
|||||||
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
||||||
)
|
)
|
||||||
# Set the validation transforms
|
# Set the validation transforms
|
||||||
dataset["validation"].set_transform(preprocess_val)
|
dataset["validation"].set_transform(preprocess_val_batch_fn)
|
||||||
|
|
||||||
# Initialize our trainer
|
# Initialize our trainer
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|||||||
@@ -18,9 +18,10 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import albumentations as A
|
||||||
import datasets
|
import datasets
|
||||||
import evaluate
|
import evaluate
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -28,12 +29,10 @@ import torch
|
|||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from albumentations.pytorch import ToTensorV2
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
from PIL import Image
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
|
||||||
from torchvision.transforms import functional
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -57,123 +56,23 @@ logger = get_logger(__name__)
|
|||||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||||
|
|
||||||
|
|
||||||
def pad_if_smaller(img, size, fill=0):
|
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
|
||||||
min_size = min(img.size)
|
"""Set the `0` label to 255 and then reduce all other labels by 1.
|
||||||
if min_size < size:
|
|
||||||
original_width, original_height = img.size
|
|
||||||
pad_height = size - original_height if original_height < size else 0
|
|
||||||
pad_width = size - original_width if original_width < size else 0
|
|
||||||
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Initial class labels: 0 - background; 1 - road; 2 - car;
|
||||||
|
Transformed class labels: 255 - background; 0 - road; 1 - car;
|
||||||
|
|
||||||
class Compose:
|
**kwargs are required to use this function with albumentations.
|
||||||
def __init__(self, transforms):
|
"""
|
||||||
self.transforms = transforms
|
labels[labels == 0] = 255
|
||||||
|
labels = labels - 1
|
||||||
def __call__(self, image, target):
|
labels[labels == 254] = 255
|
||||||
for t in self.transforms:
|
return labels
|
||||||
image, target = t(image, target)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class Identity:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class Resize:
|
|
||||||
def __init__(self, size):
|
|
||||||
self.size = size
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = functional.resize(image, self.size)
|
|
||||||
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class RandomResize:
|
|
||||||
def __init__(self, min_size, max_size=None):
|
|
||||||
self.min_size = min_size
|
|
||||||
if max_size is None:
|
|
||||||
max_size = min_size
|
|
||||||
self.max_size = max_size
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
size = random.randint(self.min_size, self.max_size)
|
|
||||||
image = functional.resize(image, size)
|
|
||||||
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class RandomCrop:
|
|
||||||
def __init__(self, size):
|
|
||||||
self.size = size
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = pad_if_smaller(image, self.size)
|
|
||||||
target = pad_if_smaller(target, self.size, fill=255)
|
|
||||||
crop_params = transforms.RandomCrop.get_params(image, (self.size, self.size))
|
|
||||||
image = functional.crop(image, *crop_params)
|
|
||||||
target = functional.crop(target, *crop_params)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class RandomHorizontalFlip:
|
|
||||||
def __init__(self, flip_prob):
|
|
||||||
self.flip_prob = flip_prob
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
if random.random() < self.flip_prob:
|
|
||||||
image = functional.hflip(image)
|
|
||||||
target = functional.hflip(target)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class PILToTensor:
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = functional.pil_to_tensor(image)
|
|
||||||
target = torch.as_tensor(np.array(target), dtype=torch.int64)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertImageDtype:
|
|
||||||
def __init__(self, dtype):
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = functional.convert_image_dtype(image, self.dtype)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class Normalize:
|
|
||||||
def __init__(self, mean, std):
|
|
||||||
self.mean = mean
|
|
||||||
self.std = std
|
|
||||||
|
|
||||||
def __call__(self, image, target):
|
|
||||||
image = functional.normalize(image, mean=self.mean, std=self.std)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
class ReduceLabels:
|
|
||||||
def __call__(self, image, target):
|
|
||||||
if not isinstance(target, np.ndarray):
|
|
||||||
target = np.array(target).astype(np.uint8)
|
|
||||||
# avoid using underflow conversion
|
|
||||||
target[target == 0] = 255
|
|
||||||
target = target - 1
|
|
||||||
target[target == 254] = 255
|
|
||||||
|
|
||||||
target = Image.fromarray(target)
|
|
||||||
return image, target
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
|
parser = argparse.ArgumentParser(description="Finetune a transformers model on a image semantic segmentation task")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -418,69 +317,58 @@ def main():
|
|||||||
model = AutoModelForSemanticSegmentation.from_pretrained(
|
model = AutoModelForSemanticSegmentation.from_pretrained(
|
||||||
args.model_name_or_path, config=config, trust_remote_code=args.trust_remote_code
|
args.model_name_or_path, config=config, trust_remote_code=args.trust_remote_code
|
||||||
)
|
)
|
||||||
|
# `reduce_labels` is a property of dataset labels, in case we use image_processor
|
||||||
|
# pretrained on another dataset we should override the default setting
|
||||||
|
image_processor.do_reduce_labels = args.reduce_labels
|
||||||
|
|
||||||
# Preprocessing the datasets
|
# Define transforms to be applied to each image and target.
|
||||||
# Define torchvision transforms to be applied to each image + target.
|
|
||||||
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
|
|
||||||
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
|
|
||||||
if "shortest_edge" in image_processor.size:
|
if "shortest_edge" in image_processor.size:
|
||||||
# We instead set the target size to (shortest_edge, shortest_edge) here to ensure all images are batchable.
|
# We instead set the target size to (shortest_edge, shortest_edge) here to ensure all images are batchable.
|
||||||
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
|
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
|
||||||
else:
|
else:
|
||||||
size = (image_processor.size["height"], image_processor.size["width"])
|
height, width = image_processor.size["height"], image_processor.size["width"]
|
||||||
train_transforms = Compose(
|
train_transforms = A.Compose(
|
||||||
[
|
[
|
||||||
ReduceLabels() if args.reduce_labels else Identity(),
|
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
|
||||||
RandomCrop(size=size),
|
# pad image with 255, because it is ignored by loss
|
||||||
RandomHorizontalFlip(flip_prob=0.5),
|
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
|
||||||
PILToTensor(),
|
A.RandomCrop(height=height, width=width, p=1.0),
|
||||||
ConvertImageDtype(torch.float),
|
A.HorizontalFlip(p=0.5),
|
||||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||||
|
ToTensorV2(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
# Define torchvision transform to be applied to each image.
|
val_transforms = A.Compose(
|
||||||
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
|
|
||||||
val_transforms = Compose(
|
|
||||||
[
|
[
|
||||||
ReduceLabels() if args.reduce_labels else Identity(),
|
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
|
||||||
Resize(size=size),
|
A.Resize(height=height, width=width, p=1.0),
|
||||||
PILToTensor(),
|
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||||
ConvertImageDtype(torch.float),
|
ToTensorV2(),
|
||||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def preprocess_train(example_batch):
|
def preprocess_batch(example_batch, transforms: A.Compose):
|
||||||
pixel_values = []
|
pixel_values = []
|
||||||
labels = []
|
labels = []
|
||||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
for image, target in zip(example_batch["image"], example_batch["label"]):
|
||||||
image, target = train_transforms(image.convert("RGB"), target)
|
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
|
||||||
pixel_values.append(image)
|
pixel_values.append(transformed["image"])
|
||||||
labels.append(target)
|
labels.append(transformed["mask"])
|
||||||
|
|
||||||
encoding = {}
|
encoding = {}
|
||||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
|
||||||
encoding["labels"] = torch.stack(labels)
|
encoding["labels"] = torch.stack(labels).to(torch.long)
|
||||||
|
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
def preprocess_val(example_batch):
|
# Preprocess function for dataset should have only one input argument,
|
||||||
pixel_values = []
|
# so we use partial to pass transforms
|
||||||
labels = []
|
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
|
||||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
|
||||||
image, target = val_transforms(image.convert("RGB"), target)
|
|
||||||
pixel_values.append(image)
|
|
||||||
labels.append(target)
|
|
||||||
|
|
||||||
encoding = {}
|
|
||||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
|
||||||
encoding["labels"] = torch.stack(labels)
|
|
||||||
|
|
||||||
return encoding
|
|
||||||
|
|
||||||
with accelerator.main_process_first():
|
with accelerator.main_process_first():
|
||||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
train_dataset = dataset["train"].with_transform(preprocess_train_batch_fn)
|
||||||
eval_dataset = dataset["validation"].with_transform(preprocess_val)
|
eval_dataset = dataset["validation"].with_transform(preprocess_val_batch_fn)
|
||||||
|
|
||||||
train_dataloader = DataLoader(
|
train_dataloader = DataLoader(
|
||||||
train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
|
train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
|
||||||
@@ -726,7 +614,7 @@ def main():
|
|||||||
f"eval_{k}": v.tolist() if isinstance(v, np.ndarray) else v for k, v in eval_metrics.items()
|
f"eval_{k}": v.tolist() if isinstance(v, np.ndarray) else v for k, v in eval_metrics.items()
|
||||||
}
|
}
|
||||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
json.dump(all_results, f)
|
json.dump(all_results, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user