From 71d18d08319db50a1a5bd022aed54d2613eec3b9 Mon Sep 17 00:00:00 2001 From: Kenneth Enevoldsen Date: Mon, 16 May 2022 13:40:27 +0200 Subject: [PATCH] fixed bug in run_mlm_flax_stream.py (#17203) * fixed bug run_mlm_flax_stream.py Fixed bug caused by an update to tokenizer keys introduced in recent transformers versions (between `4.6.2` and `4.18.0`) where additional keys were introduced to the tokenizer output. * Update run_mlm_flax_stream.py * adding missing parenthesis * formatted to black * remove cols from dataset instead * reformat to black * moved rem. columns to map * formatted to black Co-authored-by: KennethEnevoldsen --- .../dataset-streaming/run_mlm_flax_stream.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py index c64979d40f..f0f3e873d8 100755 --- a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py +++ b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py @@ -288,8 +288,10 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length): tokenized_samples = next(train_iterator) i += len(tokenized_samples["input_ids"]) - # concatenate tokenized samples to list - samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()} + # concatenate tokenized samples to list (excluding "id" and "text") + samples = { + k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"] + } # Concatenated tokens are split to lists of length `max_seq_length`. # Note that remainedr of % max_seq_length are thrown away. 
@@ -407,10 +409,7 @@ if __name__ == "__main__": def tokenize_function(examples): return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True) - tokenized_datasets = dataset.map( - tokenize_function, - batched=True, - ) + tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys())) shuffle_seed = training_args.seed tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)