[examples] max samples can't be bigger than the len of dataset (#16501)
* [examples] max samples can't be bigger than the len of dataset
* do tf and flax
This commit is contained in:
@@ -613,7 +613,8 @@ def main():
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = dataset["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
+ train_dataset = train_dataset.select(range(max_train_samples))
|
||||
# remove problematic examples
|
||||
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
|
||||
# instead here.)
|
||||
@@ -646,7 +647,8 @@ def main():
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = dataset["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
# remove problematic examples
|
||||
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
|
||||
# instead here.)
|
||||
@@ -675,7 +677,8 @@ def main():
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
predict_dataset = dataset["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||
+ predict_dataset = predict_dataset.select(range(max_predict_samples))
|
||||
# remove problematic examples
|
||||
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
|
||||
# instead here.)
|
||||
|
||||
Reference in New Issue
Block a user