From cd7961b632c66299a700d8d54912f63a0348d58d Mon Sep 17 00:00:00 2001 From: Nicholas Broad Date: Mon, 14 Jun 2021 08:11:13 -0400 Subject: [PATCH] Use text_column_name variable instead of "text" (#12132) * Use text_column_name variable instead of "text" `text_column_name` was already defined above where I made the changes and it was also used below where I made changes. This is a very minor change. If a dataset does not use "text" as the column name, then the `tokenize_function` will now use whatever column is assigned to `text_column_name`. `text_column_name` is just the first column name if "text" is not a column name. It makes the function a little more robust, though I would assume that 90% + of datasets use "text" anyway. * black formatting * make style Co-authored-by: Nicholas Broad --- examples/pytorch/language-modeling/run_mlm.py | 6 ++++-- examples/pytorch/language-modeling/run_mlm_no_trainer.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 929a9d6ff9..7612e05226 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -345,9 +345,11 @@ def main(): def tokenize_function(examples): # Remove empty lines - examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()] + examples[text_column_name] = [ + line for line in examples[text_column_name] if len(line) > 0 and not line.isspace() + ] return tokenizer( - examples["text"], + examples[text_column_name], padding=padding, truncation=True, max_length=max_seq_length, diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index 1731b244da..27e61056df 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -327,9 +327,11 @@ def main(): def tokenize_function(examples): # Remove empty lines - examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()] + examples[text_column_name] = [ + line for line in examples[text_column_name] if len(line) > 0 and not line.isspace() + ] return tokenizer( - examples["text"], + examples[text_column_name], padding=padding, truncation=True, max_length=max_seq_length,