Fix input data file extension in examples (#28741)

This commit is contained in:
Klaus Hipp
2024-01-29 11:06:31 +01:00
committed by GitHub
parent 5649c0cbb8
commit 39fa400969
23 changed files with 49 additions and 23 deletions

View File

@@ -311,11 +311,13 @@ def main():
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.train_file.split(".")[-1]
extension = data_args.test_file.split(".")[-1]
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.

View File

@@ -339,9 +339,10 @@ def main():
data_files = {}
if args.train_file is not None:
data_files["train"] = args.train_file
extension = args.train_file.split(".")[-1]
if args.validation_file is not None:
data_files["validation"] = args.validation_file
extension = args.train_file.split(".")[-1]
extension = args.validation_file.split(".")[-1]
raw_datasets = load_dataset(extension, data_files=data_files)
# Trim a number of training examples
if args.debug: