[examples] max samples can't be bigger than the len of dataset (#16501)

* [examples] max samples can't be bigger than then len of dataset

* do tf and flax
This commit is contained in:
Stas Bekman
2022-03-30 12:33:16 -07:00
committed by GitHub
parent c4deb7b3ae
commit a73281e3e4
26 changed files with 154 additions and 77 deletions

View File

@@ -369,7 +369,8 @@ def main():
train_dataset = raw_datasets["train"]
non_label_columns = [feature for feature in train_dataset.features if feature not in ("label", "labels")]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function,
@@ -385,7 +386,8 @@ def main():
if not training_args.do_train:
non_label_columns = [feature for feature in eval_dataset.features if feature not in ("label", "labels")]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_dataset.map(
preprocess_function,