Fix input data file extension in examples (#28741)

This commit is contained in:
Klaus Hipp
2024-01-29 11:06:31 +01:00
committed by GitHub
parent 5649c0cbb8
commit 39fa400969
23 changed files with 49 additions and 23 deletions

View File

@@ -297,9 +297,10 @@ def main():
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
extension = data_args.validation_file.split(".")[-1]
if extension == "txt":
extension = "text"
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

View File

@@ -285,9 +285,10 @@ def main():
data_files = {}
if args.train_file is not None:
data_files["train"] = args.train_file
extension = args.train_file.split(".")[-1]
if args.validation_file is not None:
data_files["validation"] = args.validation_file
extension = args.train_file.split(".")[-1]
extension = args.validation_file.split(".")[-1]
raw_datasets = load_dataset(extension, data_files=data_files)
# Trim a number of training examples
if args.debug:

View File

@@ -271,9 +271,10 @@ def main():
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
extension = data_args.validation_file.split(".")[-1]
if extension == "txt":
extension = "text"
datasets = load_dataset(extension, data_files=data_files)

View File

@@ -517,9 +517,10 @@ if __name__ == "__main__":
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
extension = data_args.validation_file.split(".")[-1]
if extension == "txt":
extension = "text"
datasets = load_dataset(extension, data_files=data_files)