From 1d651868d64e8f54f7bf6b687fbcdac832039334 Mon Sep 17 00:00:00 2001 From: Leandro von Werra Date: Thu, 23 Dec 2021 14:59:11 +0100 Subject: [PATCH] add custom stopping criteria to human eval script (#14897) --- .../codeparrot/requirements.txt | 2 +- .../codeparrot/scripts/human_eval.py | 44 ++++++++++++++++--- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/examples/research_projects/codeparrot/requirements.txt b/examples/research_projects/codeparrot/requirements.txt index 3333e3dc37..14b27c8e25 100644 --- a/examples/research_projects/codeparrot/requirements.txt +++ b/examples/research_projects/codeparrot/requirements.txt @@ -1,4 +1,4 @@ -transformers==4.12.2 +transformers==4.15.0 datasets==1.16.0 accelerate==0.5.1 wandb==0.12.0 diff --git a/examples/research_projects/codeparrot/scripts/human_eval.py b/examples/research_projects/codeparrot/scripts/human_eval.py index d70655b996..ea9eb82146 100644 --- a/examples/research_projects/codeparrot/scripts/human_eval.py +++ b/examples/research_projects/codeparrot/scripts/human_eval.py @@ -8,12 +8,40 @@ from tqdm import tqdm import transformers from arguments import HumanEvalArguments -from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline, set_seed +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + HfArgumentParser, + StoppingCriteria, + StoppingCriteriaList, + pipeline, + set_seed, +) + + +EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"] + + +class EndOfFunctionCriteria(StoppingCriteria): + """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed.""" + + def __init__(self, start_length, eof_strings, tokenizer): + self.start_length = start_length + self.eof_strings = eof_strings + self.tokenizer = tokenizer + + def __call__(self, input_ids, scores, **kwargs): + """Returns true if all generated sequences contain any of the end-of-function strings.""" + decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :]) + done = [] + for decoded_generation in decoded_generations: + done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings])) + return all(done) def first_block(string): """Split off first block of code by scanning for class, def etc. on newlines.""" - return re.split("\nclass|\ndef|\n#|\n@|\nprint|\nif", string)[0].rstrip() + return re.split("|".join(EOF_STRINGS), string)[0].rstrip() def complete_code(pipe, prompt, num_completions=1, **gen_kwargs): @@ -39,6 +67,11 @@ def main(): set_seed(args.seed) + # Load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt) + model = AutoModelForCausalLM.from_pretrained(args.model_ckpt) + pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int) + # Generation settings gen_kwargs = { "do_sample": args.do_sample, @@ -46,13 +79,9 @@ def main(): "max_new_tokens": args.max_new_tokens, "top_p": args.top_p, "top_k": args.top_k, + "stopping_criteria": StoppingCriteriaList([EndOfFunctionCriteria(0, EOF_STRINGS, tokenizer)]), } - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt) - model = AutoModelForCausalLM.from_pretrained(args.model_ckpt) - pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int) - # Load evaluation dataset and metric human_eval = load_dataset("openai_humaneval") code_eval_metric = load_metric("code_eval") @@ -72,6 +101,7 @@ def main(): for task in tqdm(range(n_tasks)): task_generations = [] prompt = human_eval["test"][task]["prompt"].strip() + gen_kwargs["stopping_criteria"][0].start_length = len(tokenizer(prompt)["input_ids"]) for batch in range(args.n_samples // args.batch_size): task_generations.extend(complete_code(pipe, prompt, num_completions=args.batch_size, **gen_kwargs)) generations.append([prompt + gen for gen in task_generations])