[examples] max samples can't be bigger than the len of dataset (#16501)
* [examples] max samples can't be bigger than then len of dataset * do tf and flax
This commit is contained in:
@@ -602,7 +602,8 @@ def main():
|
||||
train_dataset = raw_datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
# We will select sample from whole data if agument is specified
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
train_dataset = train_dataset.select(range(max_train_samples))
|
||||
# Create train feature from dataset
|
||||
train_dataset = train_dataset.map(
|
||||
prepare_train_features,
|
||||
@@ -613,7 +614,8 @@ def main():
|
||||
)
|
||||
if data_args.max_train_samples is not None:
|
||||
# Number of samples might increase during Feature Creation, We select only specified max samples
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
train_dataset = train_dataset.select(range(max_train_samples))
|
||||
processed_raw_datasets["train"] = train_dataset
|
||||
|
||||
# Validation preprocessing
|
||||
@@ -669,7 +671,8 @@ def main():
|
||||
eval_examples = raw_datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
# We will select sample from whole data
|
||||
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
|
||||
max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
|
||||
eval_examples = eval_examples.select(range(max_eval_samples))
|
||||
# Validation Feature Creation
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
@@ -680,7 +683,8 @@ def main():
|
||||
)
|
||||
if data_args.max_eval_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
processed_raw_datasets["validation"] = eval_dataset
|
||||
|
||||
if training_args.do_predict:
|
||||
@@ -700,7 +704,8 @@ def main():
|
||||
)
|
||||
if data_args.max_predict_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||
predict_dataset = predict_dataset.select(range(max_predict_samples))
|
||||
processed_raw_datasets["test"] = predict_dataset
|
||||
# endregion
|
||||
|
||||
|
||||
Reference in New Issue
Block a user