Improve image classification example (#16585)

* Improve README

* Make dataset_name argument optional

* Improve local data

* Fix bug

* Improve README some more

* Apply suggestions from code review

* Improve README

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-04-14 18:10:52 +02:00
committed by GitHub
parent 3e4eec47f5
commit 048443db86
2 changed files with 114 additions and 67 deletions

View File

@@ -72,13 +72,15 @@ def pil_loader(path: str):
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
them on the command line.
"""
dataset_name: Optional[str] = field(
default="nateraw/image-folder", metadata={"help": "Name of a dataset from the datasets package"}
default=None,
metadata={
"help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
},
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
@@ -104,12 +106,10 @@ class DataTrainingArguments:
)
def __post_init__(self):
data_files = dict()
if self.train_dir is not None:
data_files["train"] = self.train_dir
if self.validation_dir is not None:
data_files["val"] = self.validation_dir
self.data_files = data_files if data_files else None
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
raise ValueError(
"You must specify either a dataset name from the hub or a train and/or validation directory."
)
@dataclass
@@ -201,25 +201,37 @@ def main():
)
# Initialize our dataset and prepare it for the 'image-classification' task.
ds = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
data_files=data_args.data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
use_auth_token=True if model_args.use_auth_token else None,
)
if data_args.dataset_name is not None:
dataset = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
use_auth_token=True if model_args.use_auth_token else None,
)
else:
data_files = {}
if data_args.train_dir is not None:
data_files["train"] = os.path.join(data_args.train_dir, "**")
if data_args.validation_dir is not None:
data_files["validation"] = os.path.join(data_args.validation_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
)
# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
split = ds["train"].train_test_split(data_args.train_val_split)
ds["train"] = split["train"]
ds["validation"] = split["test"]
split = dataset["train"].train_test_split(data_args.train_val_split)
dataset["train"] = split["train"]
dataset["validation"] = split["test"]
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = ds["train"].features["labels"].names
labels = dataset["train"].features["labels"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
label2id[label] = str(i)
@@ -291,29 +303,31 @@ def main():
return example_batch
if training_args.do_train:
if "train" not in ds:
if "train" not in dataset:
raise ValueError("--do_train requires a train dataset")
if data_args.max_train_samples is not None:
ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
dataset["train"] = (
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
# Set the training transforms
ds["train"].set_transform(train_transforms)
dataset["train"].set_transform(train_transforms)
if training_args.do_eval:
if "validation" not in ds:
if "validation" not in dataset:
raise ValueError("--do_eval requires a validation dataset")
if data_args.max_eval_samples is not None:
ds["validation"] = (
ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
dataset["validation"] = (
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# Set the validation transforms
ds["validation"].set_transform(val_transforms)
dataset["validation"].set_transform(val_transforms)
# Initalize our trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=ds["train"] if training_args.do_train else None,
eval_dataset=ds["validation"] if training_args.do_eval else None,
train_dataset=dataset["train"] if training_args.do_train else None,
eval_dataset=dataset["validation"] if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=feature_extractor,
data_collator=collate_fn,
@@ -343,7 +357,7 @@ def main():
"finetuned_from": model_args.model_name_or_path,
"tasks": "image-classification",
"dataset": data_args.dataset_name,
"tags": ["image-classification"],
"tags": ["image-classification", "vision"],
}
if training_args.push_to_hub:
trainer.push_to_hub(**kwargs)