add custom stopping criteria to human eval script (#14897)
This commit is contained in:
committed by
GitHub
parent
6b655cc63f
commit
1d651868d6
@@ -1,4 +1,4 @@
|
|||||||
transformers==4.12.2
|
transformers==4.15.0
|
||||||
datasets==1.16.0
|
datasets==1.16.0
|
||||||
accelerate==0.5.1
|
accelerate==0.5.1
|
||||||
wandb==0.12.0
|
wandb==0.12.0
|
||||||
|
|||||||
@@ -8,12 +8,40 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from arguments import HumanEvalArguments
|
from arguments import HumanEvalArguments
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline, set_seed
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
HfArgumentParser,
|
||||||
|
StoppingCriteria,
|
||||||
|
StoppingCriteriaList,
|
||||||
|
pipeline,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]
|
||||||
|
|
||||||
|
|
||||||
|
class EndOfFunctionCriteria(StoppingCriteria):
|
||||||
|
"""Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
|
||||||
|
|
||||||
|
def __init__(self, start_length, eof_strings, tokenizer):
|
||||||
|
self.start_length = start_length
|
||||||
|
self.eof_strings = eof_strings
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def __call__(self, input_ids, scores, **kwargs):
|
||||||
|
"""Returns true if all generated sequences contain any of the end-of-function strings."""
|
||||||
|
decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
|
||||||
|
done = []
|
||||||
|
for decoded_generation in decoded_generations:
|
||||||
|
done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings]))
|
||||||
|
return all(done)
|
||||||
|
|
||||||
|
|
||||||
def first_block(string):
|
def first_block(string):
|
||||||
"""Split off first block of code by scanning for class, def etc. on newlines."""
|
"""Split off first block of code by scanning for class, def etc. on newlines."""
|
||||||
return re.split("\nclass|\ndef|\n#|\n@|\nprint|\nif", string)[0].rstrip()
|
return re.split("|".join(EOF_STRINGS), string)[0].rstrip()
|
||||||
|
|
||||||
|
|
||||||
def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
|
def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
|
||||||
@@ -39,6 +67,11 @@ def main():
|
|||||||
|
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
|
# Load model and tokenizer
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
|
||||||
|
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)
|
||||||
|
|
||||||
# Generation settings
|
# Generation settings
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
"do_sample": args.do_sample,
|
"do_sample": args.do_sample,
|
||||||
@@ -46,13 +79,9 @@ def main():
|
|||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
"top_p": args.top_p,
|
"top_p": args.top_p,
|
||||||
"top_k": args.top_k,
|
"top_k": args.top_k,
|
||||||
|
"stopping_criteria": StoppingCriteriaList([EndOfFunctionCriteria(0, EOF_STRINGS, tokenizer)]),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Load model and tokenizer
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
|
|
||||||
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)
|
|
||||||
|
|
||||||
# Load evaluation dataset and metric
|
# Load evaluation dataset and metric
|
||||||
human_eval = load_dataset("openai_humaneval")
|
human_eval = load_dataset("openai_humaneval")
|
||||||
code_eval_metric = load_metric("code_eval")
|
code_eval_metric = load_metric("code_eval")
|
||||||
@@ -72,6 +101,7 @@ def main():
|
|||||||
for task in tqdm(range(n_tasks)):
|
for task in tqdm(range(n_tasks)):
|
||||||
task_generations = []
|
task_generations = []
|
||||||
prompt = human_eval["test"][task]["prompt"].strip()
|
prompt = human_eval["test"][task]["prompt"].strip()
|
||||||
|
gen_kwargs["stopping_criteria"][0].start_length = len(tokenizer(prompt)["input_ids"])
|
||||||
for batch in range(args.n_samples // args.batch_size):
|
for batch in range(args.n_samples // args.batch_size):
|
||||||
task_generations.extend(complete_code(pipe, prompt, num_completions=args.batch_size, **gen_kwargs))
|
task_generations.extend(complete_code(pipe, prompt, num_completions=args.batch_size, **gen_kwargs))
|
||||||
generations.append([prompt + gen for gen in task_generations])
|
generations.append([prompt + gen for gen in task_generations])
|
||||||
|
|||||||
Reference in New Issue
Block a user