CodeParrot data pretokenization (#16932)
* add pretokenization arguments
* add pretokenization script
* add support for pretokenized data
* reformat code
* fix run command for training
* fix model call from config
* remove a package
* add comments on pretokenization in the readme
* remove explicit parallelization

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* update readme

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* update readme - remove username

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* update readme - remove username

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* keep data parallelization
* reformat code
* reformat code
* update readme
* reformat code
* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna ben allal <loubnabenallal@gmail.com>
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from arguments import PretokenizationArguments
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
def tokenize(example):
    """Tokenize one dataset example and record its character-to-token ratio.

    Args:
        example: Mapping with a "content" entry holding the text to tokenize.

    Returns:
        A dict with two entries:
        "input_ids" -- the ids the tokenizer produced for the content
        (truncation disabled), and
        "ratio_char_token" -- number of characters divided by number of
        tokens, or 0.0 when the tokenizer produced no tokens.
    """
    content = example["content"]
    # Uses the module-level `tokenizer`, which the script creates before
    # mapping this function over the dataset.
    input_ids = tokenizer(content, truncation=False)["input_ids"]
    num_tokens = len(input_ids)
    # Guard the division: content that yields no tokens (presumably empty
    # files) would otherwise raise ZeroDivisionError and abort the whole map.
    ratio = len(content) / num_tokens if num_tokens else 0.0
    return {"input_ids": input_ids, "ratio_char_token": ratio}
|
||||
|
||||
|
||||
# Read the pretokenization options from the command line.
parser = HfArgumentParser(PretokenizationArguments)
args = parser.parse_args()
# When no worker count was given, use the CPU count reported by the OS.
args.num_workers = multiprocessing.cpu_count() if args.num_workers is None else args.num_workers
# Kept as a module-level name because `tokenize` looks it up globally.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)

# Columns dropped by the map call below, so only the values returned by
# `tokenize` (plus any column not listed here) remain in the result.
raw_columns = [
    "repo_name",
    "path",
    "copies",
    "size",
    "content",
    "license",
    "hash",
    "line_mean",
    "line_max",
    "alpha_frac",
    "autogenerated",
]

# Step 1: load the "train" split of the requested dataset.
t_start = time.time()
ds = load_dataset(args.dataset_name, split="train")
print(f"Dataset loaded in {time.time()-t_start:.2f}s")

# Step 2: tokenize every example using `args.num_workers` processes.
t_start = time.time()
ds = ds.map(tokenize, num_proc=args.num_workers, remove_columns=raw_columns)
print(f"Dataset tokenized in {time.time()-t_start:.2f}s")

# Step 3: upload the tokenized dataset to the configured hub repository.
t_start = time.time()
ds.push_to_hub(args.tokenized_data_repo)
print(f"Data pushed to the hub in {time.time()-t_start:.2f}s")
|
||||
Reference in New Issue
Block a user