From 074d6b75fde01faa0ec39afd6a4158e8e8629f16 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Tue, 31 Jan 2023 15:23:47 +0100 Subject: [PATCH] Simplify column_names in run_clm/mlm (#21382) * simplify column_names in run_clm * simplify column_names in run_mlm * minor --- examples/pytorch/language-modeling/run_clm.py | 10 ++-------- examples/pytorch/language-modeling/run_mlm.py | 10 ++-------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index 9a24c55456..12adcdae1b 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -419,15 +419,9 @@ def main(): # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: - if data_args.streaming: - column_names = raw_datasets["train"].features.keys() - else: - column_names = raw_datasets["train"].column_names + column_names = list(raw_datasets["train"].features) else: - if data_args.streaming: - column_names = raw_datasets["validation"].features.keys() - else: - column_names = raw_datasets["validation"].column_names + column_names = list(raw_datasets["validation"].features) text_column_name = "text" if "text" in column_names else column_names[0] # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 8cf76896d1..16dc11abf2 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -405,15 +405,9 @@ def main(): # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: - if data_args.streaming: - column_names = raw_datasets["train"].features.keys() - else: - column_names = raw_datasets["train"].column_names + column_names = list(raw_datasets["train"].features) else: - if data_args.streaming: - column_names = raw_datasets["validation"].features.keys() - else: - column_names = raw_datasets["validation"].column_names + column_names = list(raw_datasets["validation"].features) text_column_name = "text" if "text" in column_names else column_names[0] if data_args.max_seq_length is None: