Fix and simplify semantic-segmentation example (#30145)
* Remove unused augmentation * Fix pad_if_smaller() and remove unused augmentation * Add indentation * Fix requirements * Update dataset use instructions * Replace transforms with albumentations * Replace identity transform with None * Fixing formatting * Fixed comment place
This commit is contained in:
committed by
GitHub
parent
41579763ee
commit
56d001b26f
@@ -25,3 +25,4 @@ torchaudio
|
||||
jiwer
|
||||
librosa
|
||||
evaluate >= 0.2.0
|
||||
albumentations
|
||||
|
||||
@@ -97,6 +97,10 @@ The script leverages the [🤗 Trainer API](https://huggingface.co/docs/transfor
|
||||
|
||||
Here we show how to fine-tune a [SegFormer](https://huggingface.co/nvidia/mit-b0) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset:
|
||||
|
||||
In order to use `segments/sidewalk-semantic`:
|
||||
- Log in to Hugging Face with `huggingface-cli login` (token can be accessed [here](https://huggingface.co/settings/tokens)).
|
||||
- Accept terms of use for `sidewalk-semantic` on [dataset page](https://huggingface.co/datasets/segments/sidewalk-semantic).
|
||||
|
||||
```bash
|
||||
python run_semantic_segmentation.py \
|
||||
--model_name_or_path nvidia/mit-b0 \
|
||||
@@ -105,7 +109,6 @@ python run_semantic_segmentation.py \
|
||||
--remove_unused_columns False \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--evaluation_strategy steps \
|
||||
--push_to_hub \
|
||||
--push_to_hub_model_id segformer-finetuned-sidewalk-10k-steps \
|
||||
--max_steps 10000 \
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
git://github.com/huggingface/accelerate.git
|
||||
datasets >= 2.0.0
|
||||
torch >= 1.3
|
||||
accelerate
|
||||
evaluate
|
||||
Pillow
|
||||
albumentations
|
||||
@@ -16,21 +16,20 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import albumentations as A
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@@ -57,118 +56,19 @@ check_min_version("4.40.0.dev0")
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
||||
def pad_if_smaller(img, size, fill=0):
|
||||
size = (size, size) if isinstance(size, int) else size
|
||||
original_width, original_height = img.size
|
||||
pad_height = size[1] - original_height if original_height < size[1] else 0
|
||||
pad_width = size[0] - original_width if original_width < size[0] else 0
|
||||
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
|
||||
return img
|
||||
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
|
||||
"""Set `0` label as with value 255 and then reduce all other labels by 1.
|
||||
|
||||
Example:
|
||||
Initial class labels: 0 - background; 1 - road; 2 - car;
|
||||
Transformed class labels: 255 - background; 0 - road; 1 - car;
|
||||
|
||||
class Compose:
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, image, target):
|
||||
for t in self.transforms:
|
||||
image, target = t(image, target)
|
||||
return image, target
|
||||
|
||||
|
||||
class Identity:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, image, target):
|
||||
return image, target
|
||||
|
||||
|
||||
class Resize:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.resize(image, self.size)
|
||||
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomResize:
|
||||
def __init__(self, min_size, max_size=None):
|
||||
self.min_size = min_size
|
||||
if max_size is None:
|
||||
max_size = min_size
|
||||
self.max_size = max_size
|
||||
|
||||
def __call__(self, image, target):
|
||||
size = random.randint(self.min_size, self.max_size)
|
||||
image = functional.resize(image, size)
|
||||
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomCrop:
|
||||
def __init__(self, size):
|
||||
self.size = size if isinstance(size, tuple) else (size, size)
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = pad_if_smaller(image, self.size)
|
||||
target = pad_if_smaller(target, self.size, fill=255)
|
||||
crop_params = transforms.RandomCrop.get_params(image, self.size)
|
||||
image = functional.crop(image, *crop_params)
|
||||
target = functional.crop(target, *crop_params)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomHorizontalFlip:
|
||||
def __init__(self, flip_prob):
|
||||
self.flip_prob = flip_prob
|
||||
|
||||
def __call__(self, image, target):
|
||||
if random.random() < self.flip_prob:
|
||||
image = functional.hflip(image)
|
||||
target = functional.hflip(target)
|
||||
return image, target
|
||||
|
||||
|
||||
class PILToTensor:
|
||||
def __call__(self, image, target):
|
||||
image = functional.pil_to_tensor(image)
|
||||
target = torch.as_tensor(np.array(target), dtype=torch.int64)
|
||||
return image, target
|
||||
|
||||
|
||||
class ConvertImageDtype:
|
||||
def __init__(self, dtype):
|
||||
self.dtype = dtype
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.convert_image_dtype(image, self.dtype)
|
||||
return image, target
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.normalize(image, mean=self.mean, std=self.std)
|
||||
return image, target
|
||||
|
||||
|
||||
class ReduceLabels:
|
||||
def __call__(self, image, target):
|
||||
if not isinstance(target, np.ndarray):
|
||||
target = np.array(target).astype(np.uint8)
|
||||
# avoid using underflow conversion
|
||||
target[target == 0] = 255
|
||||
target = target - 1
|
||||
target[target == 254] = 255
|
||||
|
||||
target = Image.fromarray(target)
|
||||
return image, target
|
||||
**kwargs are required to use this function with albumentations.
|
||||
"""
|
||||
labels[labels == 0] = 255
|
||||
labels = labels - 1
|
||||
labels[labels == 254] = 255
|
||||
return labels
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -365,7 +265,7 @@ def main():
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: str(k) for k, v in id2label.items()}
|
||||
|
||||
# Load the mean IoU metric from the datasets package
|
||||
# Load the mean IoU metric from the evaluate package
|
||||
metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)
|
||||
|
||||
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
||||
@@ -424,64 +324,62 @@ def main():
|
||||
token=model_args.token,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
# `reduce_labels` is a property of dataset labels, in case we use image_processor
|
||||
# pretrained on another dataset we should override the default setting
|
||||
image_processor.do_reduce_labels = data_args.reduce_labels
|
||||
|
||||
# Define torchvision transforms to be applied to each image + target.
|
||||
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
|
||||
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
|
||||
# Define transforms to be applied to each image and target.
|
||||
if "shortest_edge" in image_processor.size:
|
||||
# We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
|
||||
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
|
||||
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
train_transforms = Compose(
|
||||
height, width = image_processor.size["height"], image_processor.size["width"]
|
||||
train_transforms = A.Compose(
|
||||
[
|
||||
ReduceLabels() if data_args.reduce_labels else Identity(),
|
||||
RandomCrop(size=size),
|
||||
RandomHorizontalFlip(flip_prob=0.5),
|
||||
PILToTensor(),
|
||||
ConvertImageDtype(torch.float),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
A.Lambda(
|
||||
name="reduce_labels",
|
||||
mask=reduce_labels_transform if data_args.reduce_labels else None,
|
||||
p=1.0,
|
||||
),
|
||||
# pad image with 255, because it is ignored by loss
|
||||
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
|
||||
A.RandomCrop(height=height, width=width, p=1.0),
|
||||
A.HorizontalFlip(p=0.5),
|
||||
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||
ToTensorV2(),
|
||||
]
|
||||
)
|
||||
# Define torchvision transform to be applied to each image.
|
||||
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
|
||||
val_transforms = Compose(
|
||||
val_transforms = A.Compose(
|
||||
[
|
||||
ReduceLabels() if data_args.reduce_labels else Identity(),
|
||||
Resize(size=size),
|
||||
PILToTensor(),
|
||||
ConvertImageDtype(torch.float),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
A.Lambda(
|
||||
name="reduce_labels",
|
||||
mask=reduce_labels_transform if data_args.reduce_labels else None,
|
||||
p=1.0,
|
||||
),
|
||||
A.Resize(height=height, width=width, p=1.0),
|
||||
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||
ToTensorV2(),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(example_batch):
|
||||
def preprocess_batch(example_batch, transforms: A.Compose):
|
||||
pixel_values = []
|
||||
labels = []
|
||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
||||
image, target = train_transforms(image.convert("RGB"), target)
|
||||
pixel_values.append(image)
|
||||
labels.append(target)
|
||||
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
|
||||
pixel_values.append(transformed["image"])
|
||||
labels.append(transformed["mask"])
|
||||
|
||||
encoding = {}
|
||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
||||
encoding["labels"] = torch.stack(labels)
|
||||
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
|
||||
encoding["labels"] = torch.stack(labels).to(torch.long)
|
||||
|
||||
return encoding
|
||||
|
||||
def preprocess_val(example_batch):
|
||||
pixel_values = []
|
||||
labels = []
|
||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
||||
image, target = val_transforms(image.convert("RGB"), target)
|
||||
pixel_values.append(image)
|
||||
labels.append(target)
|
||||
|
||||
encoding = {}
|
||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
||||
encoding["labels"] = torch.stack(labels)
|
||||
|
||||
return encoding
|
||||
# Preprocess function for dataset should have only one argument,
|
||||
# so we use partial to pass the transforms
|
||||
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
|
||||
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in dataset:
|
||||
@@ -491,7 +389,7 @@ def main():
|
||||
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
||||
)
|
||||
# Set the training transforms
|
||||
dataset["train"].set_transform(preprocess_train)
|
||||
dataset["train"].set_transform(preprocess_train_batch_fn)
|
||||
|
||||
if training_args.do_eval:
|
||||
if "validation" not in dataset:
|
||||
@@ -501,7 +399,7 @@ def main():
|
||||
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
||||
)
|
||||
# Set the validation transforms
|
||||
dataset["validation"].set_transform(preprocess_val)
|
||||
dataset["validation"].set_transform(preprocess_val_batch_fn)
|
||||
|
||||
# Initialize our trainer
|
||||
trainer = Trainer(
|
||||
|
||||
@@ -18,9 +18,10 @@ import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import albumentations as A
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
@@ -28,12 +29,10 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
@@ -57,123 +56,23 @@ logger = get_logger(__name__)
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
||||
def pad_if_smaller(img, size, fill=0):
|
||||
min_size = min(img.size)
|
||||
if min_size < size:
|
||||
original_width, original_height = img.size
|
||||
pad_height = size - original_height if original_height < size else 0
|
||||
pad_width = size - original_width if original_width < size else 0
|
||||
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
|
||||
return img
|
||||
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
|
||||
"""Set `0` label as with value 255 and then reduce all other labels by 1.
|
||||
|
||||
Example:
|
||||
Initial class labels: 0 - background; 1 - road; 2 - car;
|
||||
Transformed class labels: 255 - background; 0 - road; 1 - car;
|
||||
|
||||
class Compose:
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, image, target):
|
||||
for t in self.transforms:
|
||||
image, target = t(image, target)
|
||||
return image, target
|
||||
|
||||
|
||||
class Identity:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, image, target):
|
||||
return image, target
|
||||
|
||||
|
||||
class Resize:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.resize(image, self.size)
|
||||
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomResize:
|
||||
def __init__(self, min_size, max_size=None):
|
||||
self.min_size = min_size
|
||||
if max_size is None:
|
||||
max_size = min_size
|
||||
self.max_size = max_size
|
||||
|
||||
def __call__(self, image, target):
|
||||
size = random.randint(self.min_size, self.max_size)
|
||||
image = functional.resize(image, size)
|
||||
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomCrop:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = pad_if_smaller(image, self.size)
|
||||
target = pad_if_smaller(target, self.size, fill=255)
|
||||
crop_params = transforms.RandomCrop.get_params(image, (self.size, self.size))
|
||||
image = functional.crop(image, *crop_params)
|
||||
target = functional.crop(target, *crop_params)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomHorizontalFlip:
|
||||
def __init__(self, flip_prob):
|
||||
self.flip_prob = flip_prob
|
||||
|
||||
def __call__(self, image, target):
|
||||
if random.random() < self.flip_prob:
|
||||
image = functional.hflip(image)
|
||||
target = functional.hflip(target)
|
||||
return image, target
|
||||
|
||||
|
||||
class PILToTensor:
|
||||
def __call__(self, image, target):
|
||||
image = functional.pil_to_tensor(image)
|
||||
target = torch.as_tensor(np.array(target), dtype=torch.int64)
|
||||
return image, target
|
||||
|
||||
|
||||
class ConvertImageDtype:
|
||||
def __init__(self, dtype):
|
||||
self.dtype = dtype
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.convert_image_dtype(image, self.dtype)
|
||||
return image, target
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = functional.normalize(image, mean=self.mean, std=self.std)
|
||||
return image, target
|
||||
|
||||
|
||||
class ReduceLabels:
|
||||
def __call__(self, image, target):
|
||||
if not isinstance(target, np.ndarray):
|
||||
target = np.array(target).astype(np.uint8)
|
||||
# avoid using underflow conversion
|
||||
target[target == 0] = 255
|
||||
target = target - 1
|
||||
target[target == 254] = 255
|
||||
|
||||
target = Image.fromarray(target)
|
||||
return image, target
|
||||
**kwargs are required to use this function with albumentations.
|
||||
"""
|
||||
labels[labels == 0] = 255
|
||||
labels = labels - 1
|
||||
labels[labels == 254] = 255
|
||||
return labels
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
|
||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a image semantic segmentation task")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
@@ -418,69 +317,58 @@ def main():
|
||||
model = AutoModelForSemanticSegmentation.from_pretrained(
|
||||
args.model_name_or_path, config=config, trust_remote_code=args.trust_remote_code
|
||||
)
|
||||
# `reduce_labels` is a property of dataset labels, in case we use image_processor
|
||||
# pretrained on another dataset we should override the default setting
|
||||
image_processor.do_reduce_labels = args.reduce_labels
|
||||
|
||||
# Preprocessing the datasets
|
||||
# Define torchvision transforms to be applied to each image + target.
|
||||
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
|
||||
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
|
||||
# Define transforms to be applied to each image and target.
|
||||
if "shortest_edge" in image_processor.size:
|
||||
# We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
|
||||
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
|
||||
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
train_transforms = Compose(
|
||||
height, width = image_processor.size["height"], image_processor.size["width"]
|
||||
train_transforms = A.Compose(
|
||||
[
|
||||
ReduceLabels() if args.reduce_labels else Identity(),
|
||||
RandomCrop(size=size),
|
||||
RandomHorizontalFlip(flip_prob=0.5),
|
||||
PILToTensor(),
|
||||
ConvertImageDtype(torch.float),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
|
||||
# pad image with 255, because it is ignored by loss
|
||||
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
|
||||
A.RandomCrop(height=height, width=width, p=1.0),
|
||||
A.HorizontalFlip(p=0.5),
|
||||
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||
ToTensorV2(),
|
||||
]
|
||||
)
|
||||
# Define torchvision transform to be applied to each image.
|
||||
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
|
||||
val_transforms = Compose(
|
||||
val_transforms = A.Compose(
|
||||
[
|
||||
ReduceLabels() if args.reduce_labels else Identity(),
|
||||
Resize(size=size),
|
||||
PILToTensor(),
|
||||
ConvertImageDtype(torch.float),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
|
||||
A.Resize(height=height, width=width, p=1.0),
|
||||
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
|
||||
ToTensorV2(),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(example_batch):
|
||||
def preprocess_batch(example_batch, transforms: A.Compose):
|
||||
pixel_values = []
|
||||
labels = []
|
||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
||||
image, target = train_transforms(image.convert("RGB"), target)
|
||||
pixel_values.append(image)
|
||||
labels.append(target)
|
||||
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
|
||||
pixel_values.append(transformed["image"])
|
||||
labels.append(transformed["mask"])
|
||||
|
||||
encoding = {}
|
||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
||||
encoding["labels"] = torch.stack(labels)
|
||||
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
|
||||
encoding["labels"] = torch.stack(labels).to(torch.long)
|
||||
|
||||
return encoding
|
||||
|
||||
def preprocess_val(example_batch):
|
||||
pixel_values = []
|
||||
labels = []
|
||||
for image, target in zip(example_batch["image"], example_batch["label"]):
|
||||
image, target = val_transforms(image.convert("RGB"), target)
|
||||
pixel_values.append(image)
|
||||
labels.append(target)
|
||||
|
||||
encoding = {}
|
||||
encoding["pixel_values"] = torch.stack(pixel_values)
|
||||
encoding["labels"] = torch.stack(labels)
|
||||
|
||||
return encoding
|
||||
# Preprocess function for dataset should have only one input argument,
|
||||
# so we use partial to pass transforms
|
||||
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
|
||||
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
|
||||
|
||||
with accelerator.main_process_first():
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
eval_dataset = dataset["validation"].with_transform(preprocess_val)
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train_batch_fn)
|
||||
eval_dataset = dataset["validation"].with_transform(preprocess_val_batch_fn)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
|
||||
@@ -726,7 +614,7 @@ def main():
|
||||
f"eval_{k}": v.tolist() if isinstance(v, np.ndarray) else v for k, v in eval_metrics.items()
|
||||
}
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump(all_results, f)
|
||||
json.dump(all_results, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user