Simplify column_names in run_clm/mlm (#21382)

* simplify column_names in run_clm

* simplify column_names in run_mlm

* minor
This commit is contained in:
Quentin Lhoest
2023-01-31 15:23:47 +01:00
committed by GitHub
parent c21298a69b
commit 074d6b75fd
2 changed files with 4 additions and 16 deletions

View File

@@ -419,15 +419,9 @@ def main():
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
if data_args.streaming:
column_names = raw_datasets["train"].features.keys()
else:
column_names = raw_datasets["train"].column_names
column_names = list(raw_datasets["train"].features)
else:
if data_args.streaming:
column_names = raw_datasets["validation"].features.keys()
else:
column_names = raw_datasets["validation"].column_names
column_names = list(raw_datasets["validation"].features)
text_column_name = "text" if "text" in column_names else column_names[0]
# Since this will be pickled, force logger loading before tokenize_function to avoid a _LazyModule error in Hasher.

View File

@@ -405,15 +405,9 @@ def main():
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
if data_args.streaming:
column_names = raw_datasets["train"].features.keys()
else:
column_names = raw_datasets["train"].column_names
column_names = list(raw_datasets["train"].features)
else:
if data_args.streaming:
column_names = raw_datasets["validation"].features.keys()
else:
column_names = raw_datasets["validation"].column_names
column_names = list(raw_datasets["validation"].features)
text_column_name = "text" if "text" in column_names else column_names[0]
if data_args.max_seq_length is None: