diff --git a/examples/research_projects/codeparrot/README.md b/examples/research_projects/codeparrot/README.md index 761b77a6df..1eaf063d8e 100644 --- a/examples/research_projects/codeparrot/README.md +++ b/examples/research_projects/codeparrot/README.md @@ -40,6 +40,7 @@ The source of the dataset is the GitHub dump available on Google's [BigQuery](ht The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374) and some new ones: - exact deduplication using each file's hash +- near deduplication using MinHash and Jaccard similarity. MinHash with a Jaccard threshold (default=0.85) is first used to create duplicate clusters. Then these clusters are then reduced to unique files based on the exact Jaccard similarity. See `deduplicate_dataset` in `minhash_deduplication.py` for a detailed description. - filtering files with max line length > 1000 - filtering files with mean line length > 100 - fraction of alphanumeric characters < 0.25 diff --git a/examples/research_projects/codeparrot/requirements.txt b/examples/research_projects/codeparrot/requirements.txt index 267bcb9cb0..7eff3ac7f1 100644 --- a/examples/research_projects/codeparrot/requirements.txt +++ b/examples/research_projects/codeparrot/requirements.txt @@ -4,4 +4,6 @@ wandb==0.12.0 tensorboard==2.6.0 torch==1.11.0 huggingface-hub==0.1.0 -git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe \ No newline at end of file +git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe +datasketch==1.5.7 +dpu_utils \ No newline at end of file diff --git a/examples/research_projects/codeparrot/scripts/arguments.py b/examples/research_projects/codeparrot/scripts/arguments.py index 03d578cbb8..3c559c511e 100644 --- a/examples/research_projects/codeparrot/scripts/arguments.py +++ b/examples/research_projects/codeparrot/scripts/arguments.py @@ -157,6 +157,12 @@ class PreprocessingArguments: default="lvwerra/codeparrot", metadata={"help": "Name or path to the tokenizer."}, ) + near_deduplication: Optional[bool] = field( + default=False, metadata={"help": "If True, near-duplicate samples are removed."} + ) + jaccard_threshold: Optional[float] = field( + default=0.85, metadata={"help": "Jaccard threshold for near-duplicate samples."} + ) @dataclass diff --git a/examples/research_projects/codeparrot/scripts/minhash_deduplication.py b/examples/research_projects/codeparrot/scripts/minhash_deduplication.py new file mode 100644 index 0000000000..cd72dcb70c --- /dev/null +++ b/examples/research_projects/codeparrot/scripts/minhash_deduplication.py @@ -0,0 +1,270 @@ +import json +import multiprocessing as mp +import re +from collections import defaultdict +from functools import partial +from typing import Dict, List, Optional, Set, Tuple, Type + +from datasets import Dataset +from tqdm import tqdm + +from datasketch import MinHash, MinHashLSH +from dpu_utils.utils.iterators import ThreadedIterator + + +NON_ALPHA = re.compile("[^A-Za-z_0-9]") +# parameters used in DuplicationIndex +MIN_NUM_TOKENS = 10 +NUM_PERM = 256 + + +def get_min_hash(tokens: List[str]) -> Optional[MinHash]: + """Compute the MinHash of a code snippet.""" + if len(tokens) < MIN_NUM_TOKENS: + return None + min_hash = MinHash(num_perm=NUM_PERM) + for token in set(tokens): + min_hash.update(token.encode()) + return min_hash + + +def get_tokens(code: str) -> Set[str]: + """Tokenize a code snippet.""" + return set([t for t in NON_ALPHA.split(code) if len(t.strip()) > 0]) + + +class DuplicationIndex: + def __init__( + self, + *, + duplication_jaccard_threshold: float = 0.85, + ): + self._duplication_jaccard_threshold = duplication_jaccard_threshold + self._num_perm = NUM_PERM + self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm) + + self._duplicate_clusters = defaultdict(set) + + def add(self, code_key: Tuple, min_hash: MinHash) -> None: + """Add a key to _index (MinHashLSH) + the min_hash is used to query closest matches based on the jaccard_threshold. + The new key is either added to a existing cluster of one close match, + or a new cluster is created. The clusters created in this way, depend on the order of add. + + Args: + code_key (Tuple of (index, repo_name, path)): + Theoritically any hasbale key. Here we use a tuple to retrieve the information later. + min_hash: MinHash of the code_key. + """ + close_duplicates = self._index.query(min_hash) + if code_key in self._index.keys: + print(f"Duplicate key {code_key}") + return + + self._index.insert(code_key, min_hash) + if len(close_duplicates) > 0: + + for base_duplicate in close_duplicates: + if base_duplicate in self._duplicate_clusters: + self._duplicate_clusters[base_duplicate].add(code_key) + break + else: + self._duplicate_clusters[close_duplicates[0]].add(code_key) + + def get_duplicate_clusters(self) -> List[List[Dict]]: + """Export the duplicate clusters. + For each cluster, the first element is the base element of the cluster. + The base element has an estimation jaccard similarity higher than the threshold with all the other elements. + + Returns: + duplicate_clusters (List[List[Dict]]): + List of duplicate clusters. + """ + duplicate_clusters = [] + for base, duplicates in self._duplicate_clusters.items(): + cluster = [base] + list(duplicates) + # reformat the cluster to be a list of dict + cluster = [{"base_index": el[0], "repo_name": el[1], "path": el[2]} for el in cluster] + duplicate_clusters.append(cluster) + return duplicate_clusters + + def save(self, filepath) -> None: + duplicate_clusters = self.get_duplicate_clusters() + with open(filepath, "w") as f: + json.dump(duplicate_clusters, f) + + +def _compute_min_hash(element): + index, data = element + min_hash = get_min_hash([t for t in NON_ALPHA.split(data["content"]) if len(t.strip()) > 0]) + if min_hash is not None: + return (index, data["repo_name"], data["path"]), min_hash + + +def minhash_iter(dataset_iterator: Type[Dataset]): + with mp.Pool() as pool: + for data in pool.imap_unordered( + _compute_min_hash, + ThreadedIterator(dataset_iterator, max_queue_size=10000), + chunksize=100, + ): + if data is not None: + yield data + + +def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float): + """Find duplicate clusters in the dataset in two steps: + 1. Compute MinHash for each code snippet. MinHash is a tool for fast jaccard similarity estimation. + This step is computed using an asynchronous multiprocessing pool, minhash_iter + 2. Find duplicate clusters. The computed MinHash is added sequentially to the DuplicationIndex. + This step cannot be parallelized. So using asynchronous thread in the previous step helps to speed up the process. + """ + di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold) + + for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)): + di.add(filename, min_hash) + + # Returns a List[Cluster] where Cluster is List[str] with the filenames. + return di.get_duplicate_clusters() + + +def jaccard_similarity(code1: str, code2: str) -> float: + """Compute the Jaccard similarity of two code snippets.""" + tokens1 = get_tokens(code1) + tokens2 = get_tokens(code2) + return len(tokens1 & tokens2) / len(tokens1 | tokens2) + + +_shared_dataset = None + + +def _find_cluster_extremes_shared(cluster, jaccard_threshold): + """Find a reduced cluster such that each code in the origin cluster is similar to at least one code in the reduced cluster. + Two codes are similar if their Jaccard similarity is above the threshold. + + Args: + cluster (List[dict]): + cluster is a list of dict, each dict contains the following keys: + - base_index + - repo_name + - path + This is a typical output of DuplicationIndex.get_duplicate_clusters() + jaccard_threshold (float): + threshold for Jaccard similarity. + Two codes are similar if their Jaccard similarity is above the threshold. + + Returns: + extremes (List[dict]): + A reduced representation of the cluster. The field copies is added to each dict. + The copies field indicates the number of similar codes in the cluster for a extreme. + """ + extremes = [] + for element1 in cluster: + code1 = _shared_dataset[element1["base_index"]]["content"] + for element2 in extremes: + code2 = _shared_dataset[element2["base_index"]]["content"] + if jaccard_similarity(code1, code2) >= jaccard_threshold: + element2["copies"] += 1 + break + else: + element1["copies"] = 1 + extremes.append(element1) + return extremes + + +def find_extremes(cluster_list, dataset, jaccard_threshold): + """Call the _find_cluster_extremes_shared function in a parallel fashion. + + Args: + cluster_list (List[List[Dict]]): + each cluster is a list of dicts with the key base_index, + referring to the index of the base code in the dataset. + dataset (Type[Dataset]): + dataset is used to access the content of the code snippets, + using the base_index from the cluster_list. + dataset is shared between all the processes using a glabal variable (any other way to share the dataset?), + otherwise the multi processing is not speeded up. + jaccard_threshold (float): + the threshold for the jaccard similarity. The default value is 0.85 + + Returns: + extremes_list (List[Dict]): + Each cluster is reduced to extremes. + See _find_cluster_extremes_shared for the definition of extremes. + """ + global _shared_dataset + _shared_dataset = dataset + extremes_list = [] + f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold) + with mp.Pool() as pool: + for extremes in tqdm( + pool.imap_unordered( + f, + cluster_list, + ), + total=len(cluster_list), + ): + extremes_list.append(extremes) + return extremes_list + + +def deduplicate_dataset( + dataset: Type[Dataset], jaccard_threshold: float = 0.85 +) -> Tuple[Type[Dataset], List[List[Dict]]]: + """Deduplicate the dataset using minhash and jaccard similarity. + This function first generate duplicate clusters, then each cluster + is reduced to the extremes that are similar to the other elements in the cluster. + Codes are called similar if their Jaccard similarity is greater than jaccard_threshold (0.85 default). + + Args: + dataset (Type[Dataset]): + The dataset to deduplicate. + jaccard_threshold (float, default=0.85): + jaccard threshold to determine if two codes are similar + + Returns: + ds_dedup (Type[Dataset]): + The deduplicated dataset. + duplicate_clusters (List[List[Dict]]): + The list of duplicate clusters. + Each cluster is a list of dicts with the following keys: + - base_index : int + The index of the code in the original dataset. + - repo_name : str + - path : str + - copies : int + The number of copies of the code in the cluster. (find_cluster_extremes) + - is_extreme : bool + Whether the code is an extreme in the cluster. + All the codes in the cluster are removed from the dataset except the extremes. + + Example: + >>> from datasets import load_dataset + >>> from minhash_deduplication import deduplicate_dataset + >>> ds = load_dataset("lvwerra/codeparrot-clean", split="train") + >>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85) + """ + duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold) + duplicate_indices = set(x["base_index"] for cluster in duplicate_clusters for x in cluster) + extreme_dict = {} + extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold) + for extremes in extremes_clusters: + for element in extremes: + extreme_dict[element["base_index"]] = element + remove_indices = duplicate_indices - set(extreme_dict.keys()) + ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True) + + # update duplicate_clusters + for cluster in duplicate_clusters: + for element in cluster: + element["is_extreme"] = element["base_index"] in extreme_dict + if element["is_extreme"]: + element["copies"] = extreme_dict[element["base_index"]]["copies"] + + print(f"Original dataset size: {len(dataset)}") + print(f"Number of duplicate clusters: {len(duplicate_clusters)}") + print(f"Files in duplicate cluster: {len(duplicate_indices)}") + print(f"Unique files in duplicate cluster: {len(extreme_dict)}") + print(f"Filtered dataset size: {len(ds_filter)}") + + return ds_filter, duplicate_clusters diff --git a/examples/research_projects/codeparrot/scripts/preprocessing.py b/examples/research_projects/codeparrot/scripts/preprocessing.py index 0e5899f5de..3d4ec40dec 100644 --- a/examples/research_projects/codeparrot/scripts/preprocessing.py +++ b/examples/research_projects/codeparrot/scripts/preprocessing.py @@ -1,14 +1,17 @@ import gzip import hashlib +import json import multiprocessing import os import shutil import time +from pathlib import Path import numpy as np from datasets import load_dataset from arguments import PreprocessingArguments +from minhash_deduplication import deduplicate_dataset from transformers import AutoTokenizer, HfArgumentParser @@ -146,7 +149,7 @@ def filter(example, uniques, args): def compress_file(file_path): """Compress a file with g-zip.""" with open(file_path, "rb") as f_in: - with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out: + with gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out: shutil.copyfileobj(f_in, f_out) os.unlink(file_path) @@ -179,12 +182,29 @@ ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args}) print(f"Time to filter dataset: {time.time()-t_start:.2f}") print(f"Size of filtered dataset: {len(ds_filter)}") +# Deduplicate with minhash and jaccard similarity +if args.near_deduplication: + t_start = time.time() + ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold) + print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}") + print(f"Size of deduplicate dataset: {len(ds_filter)}") + # Save data in batches of samples_per_file -if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) +output_dir = Path(args.output_dir) +output_dir.mkdir(exist_ok=True) + +# save duplicate_clusters in the output_dir as artifacts +# not sure it is the right place the save it +if args.near_deduplication: + with open(output_dir / "duplicate_clusters.json", "w") as f: + json.dump(duplicate_clusters, f) + +data_dir = output_dir / "data" +data_dir.mkdir(exist_ok=True) + t_start = time.time() for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)): - file_path = f"{args.output_dir}/file-{file_number+1:012}.json" + file_path = str(data_dir / f"file-{file_number+1:012}.json") end_index = min(len(ds_filter), index + args.samples_per_file) ds_filter.select(list(range(index, end_index))).to_json(file_path) compress_file(file_path) diff --git a/examples/research_projects/codeparrot/scripts/tests/__init__.py b/examples/research_projects/codeparrot/scripts/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/research_projects/codeparrot/scripts/tests/test_deduplicate.py b/examples/research_projects/codeparrot/scripts/tests/test_deduplicate.py new file mode 100644 index 0000000000..e443827135 --- /dev/null +++ b/examples/research_projects/codeparrot/scripts/tests/test_deduplicate.py @@ -0,0 +1,30 @@ +from unittest import TestCase + +from datasets import Dataset + +from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters + + +def get_dataset(): + data_dict = { + "repo_name": ["test_repo1", "test_repo2", "test_repo3"], + "path": ["test_1.py", "test_2.py", "unit_test.py"], + "content": ["a " * 20, "a " * 30, "b " * 7], + } + dataset = Dataset.from_dict(data_dict) + return dataset + + +class MakeDuplicateClustersTest(TestCase): + def test_make_duplicate_clusters(self): + ds = get_dataset() + duplicate_clusters = make_duplicate_clusters(ds, 0.85) + self.assertEqual(len(duplicate_clusters[0]), 2) + + def test_deduplicate_dataset(self): + ds = get_dataset() + ds_filter, duplicate_clusters = deduplicate_dataset(ds) + self.assertEqual(len(ds_filter), 2) + print(duplicate_clusters) + self.assertEqual(duplicate_clusters[0][0]["copies"], 2) + self.assertEqual(duplicate_clusters[0][0]["is_extreme"], True)