[examples] max samples can't be bigger than the len of dataset (#16501)

* [examples] max samples can't be bigger than then len of dataset

* do tf and flax
This commit is contained in:
Stas Bekman
2022-03-30 12:33:16 -07:00
committed by GitHub
parent c4deb7b3ae
commit a73281e3e4
26 changed files with 154 additions and 77 deletions

View File

@@ -404,7 +404,8 @@ def main():
raise ValueError("--do_train requires a train dataset")
train_dataset = dataset["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
train_dataset = train_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
@@ -426,7 +427,8 @@ def main():
raise ValueError("--do_eval requires a train validation")
eval_dataset = dataset["validation"]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
eval_dataset = eval_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
@@ -448,7 +450,8 @@ def main():
raise ValueError("--do_predict requires a test dataset")
test_dataset = dataset["test"]
if data_args.max_eval_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(test_dataset), data_args.max_eval_samples)
test_dataset = test_dataset.select(range(max_eval_samples))
test_dataset = test_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers