Simplify column_names in run_clm/mlm (#21382)

* simplify column_names in run_clm

* simplify column_names in run_mlm

* minor
This commit is contained in:
Quentin Lhoest
2023-01-31 15:23:47 +01:00
committed by GitHub
parent c21298a69b
commit 074d6b75fd
2 changed files with 4 additions and 16 deletions

View File

@@ -419,15 +419,9 @@ def main():
# Preprocessing the datasets. # Preprocessing the datasets.
# First we tokenize all the texts. # First we tokenize all the texts.
if training_args.do_train: if training_args.do_train:
if data_args.streaming: column_names = list(raw_datasets["train"].features)
column_names = raw_datasets["train"].features.keys()
else:
column_names = raw_datasets["train"].column_names
else: else:
if data_args.streaming: column_names = list(raw_datasets["validation"].features)
column_names = raw_datasets["validation"].features.keys()
else:
column_names = raw_datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0] text_column_name = "text" if "text" in column_names else column_names[0]
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function

View File

@@ -405,15 +405,9 @@ def main():
# Preprocessing the datasets. # Preprocessing the datasets.
# First we tokenize all the texts. # First we tokenize all the texts.
if training_args.do_train: if training_args.do_train:
if data_args.streaming: column_names = list(raw_datasets["train"].features)
column_names = raw_datasets["train"].features.keys()
else:
column_names = raw_datasets["train"].column_names
else: else:
if data_args.streaming: column_names = list(raw_datasets["validation"].features)
column_names = raw_datasets["validation"].features.keys()
else:
column_names = raw_datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0] text_column_name = "text" if "text" in column_names else column_names[0]
if data_args.max_seq_length is None: if data_args.max_seq_length is None: