[examples] max samples can't be bigger than the len of dataset (#16501)

* [examples] max samples can't be bigger than then len of dataset

* do tf and flax
This commit is contained in:
Stas Bekman
2022-03-30 12:33:16 -07:00
committed by GitHub
parent c4deb7b3ae
commit a73281e3e4
26 changed files with 154 additions and 77 deletions

View File

@@ -613,7 +613,8 @@ def main():
raise ValueError("--do_train requires a train dataset")
train_dataset = dataset["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# instead here.)
@@ -646,7 +647,8 @@ def main():
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = dataset["validation"]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# instead here.)
@@ -675,7 +677,8 @@ def main():
raise ValueError("--do_predict requires a test dataset")
predict_dataset = dataset["test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# instead here.)

View File

@@ -527,14 +527,16 @@ def main():
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = lm_datasets["validation"]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()

View File

@@ -602,7 +602,8 @@ def main():
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
# We will select sample from whole data if agument is specified
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
# Create train feature from dataset
train_dataset = train_dataset.map(
prepare_train_features,
@@ -613,7 +614,8 @@ def main():
)
if data_args.max_train_samples is not None:
# Number of samples might increase during Feature Creation, We select only specified max samples
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
processed_raw_datasets["train"] = train_dataset
# Validation preprocessing
@@ -669,7 +671,8 @@ def main():
eval_examples = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
# We will select sample from whole data
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
eval_examples = eval_examples.select(range(max_eval_samples))
# Validation Feature Creation
eval_dataset = eval_examples.map(
prepare_validation_features,
@@ -680,7 +683,8 @@ def main():
)
if data_args.max_eval_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
processed_raw_datasets["validation"] = eval_dataset
if training_args.do_predict:
@@ -700,7 +704,8 @@ def main():
)
if data_args.max_predict_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
processed_raw_datasets["test"] = predict_dataset
# endregion

View File

@@ -547,7 +547,8 @@ def main():
raise ValueError("--do_train requires a train dataset")
train_dataset = dataset["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
@@ -563,7 +564,8 @@ def main():
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = dataset["validation"]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
@@ -579,7 +581,8 @@ def main():
raise ValueError("--do_predict requires a test dataset")
predict_dataset = dataset["test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
predict_dataset = predict_dataset.map(
preprocess_function,
batched=True,