[CodeParrot] Near-deduplication with jaccard similarity (#17054)

* deduplication draft

* update style

* update style test

* dummy test main

* rename modules

* rename functions

* return extremes in deduplicate_clusters

* update style

* cast str for gzip

* update doc string

* time processing

* use dataset map to compute minhash

* fill value for short token

* remove da map method

* update style

* use share object to multiprocess

* update style

* use f-string and minor fix

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>

* update style

* use module parameters

* change ds_dedup to ds_filter

* save ds_dedup

* mv test to script tests

* make jaccard threshold a parameter of deduplicate_dataset

* update style

* add doc strings

* update style

* add doc string for DuplicationIndex

* save files into data dir

* update readme

* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>

* make near deduplication optional

* move near deduplication in README

* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* use f string

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>
This commit is contained in:
Jia LI
2022-06-21 14:23:36 +02:00
committed by GitHub
parent eb16be415a
commit da2bd2ae96
7 changed files with 334 additions and 5 deletions

View File

@@ -1,14 +1,17 @@
import gzip
import hashlib
import json
import multiprocessing
import os
import shutil
import time
from pathlib import Path
import numpy as np
from datasets import load_dataset
from arguments import PreprocessingArguments
from minhash_deduplication import deduplicate_dataset
from transformers import AutoTokenizer, HfArgumentParser
@@ -146,7 +149,7 @@ def filter(example, uniques, args):
def compress_file(file_path):
"""Compress a file with g-zip."""
with open(file_path, "rb") as f_in:
with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
with gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out:
shutil.copyfileobj(f_in, f_out)
os.unlink(file_path)
@@ -179,12 +182,29 @@ ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")
# Deduplicate with minhash and jaccard similarity
if args.near_deduplication:
t_start = time.time()
ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
print(f"Size of deduplicate dataset: {len(ds_filter)}")
# Save data in batches of samples_per_file
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(exist_ok=True)
# save duplicate_clusters in the output_dir as artifacts
# not sure it is the right place the save it
if args.near_deduplication:
with open(output_dir / "duplicate_clusters.json", "w") as f:
json.dump(duplicate_clusters, f)
data_dir = output_dir / "data"
data_dir.mkdir(exist_ok=True)
t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
file_path = f"{args.output_dir}/file-{file_number+1:012}.json"
file_path = str(data_dir / f"file-{file_number+1:012}.json")
end_index = min(len(ds_filter), index + args.samples_per_file)
ds_filter.select(list(range(index, end_index))).to_json(file_path)
compress_file(file_path)