Agents use grammar (#31735)
* Allow optional use of grammars to constrain generation
This commit is contained in:
@@ -119,10 +119,12 @@ def llm_engine(messages, stop_sequences=["Task"]) -> str:
|
|||||||
```
|
```
|
||||||
|
|
||||||
You could use any `llm_engine` method as long as:
|
You could use any `llm_engine` method as long as:
|
||||||
1. it follows the [messages format](./chat_templating.md) for its input (`List[Dict[str, str]]`) and returns a `str`
|
1. it follows the [messages format](./chat_templating.md) (`List[Dict[str, str]]`) for its input `messages`, and it returns a `str`.
|
||||||
2. it stops generating outputs at the sequences passed in the argument `stop`
|
2. it stops generating outputs at the sequences passed in the argument `stop_sequences`
|
||||||
|
|
||||||
You also need a `tools` argument which accepts a list of `Tools`. You can provide an empty list for `tools`, but use the default toolbox with the optional argument `add_base_tools=True`.
|
Additionally, `llm_engine` can also take a `grammar` argument. In the case where you specify a `grammar` upon agent initialization, this argument will be passed to the calls to llm_engine, with the `grammar` that you defined upon initialization, to allow [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance) in order to force properly-formatted agent outputs.
|
||||||
|
|
||||||
|
You will also need a `tools` argument which accepts a list of `Tools` - it can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
|
||||||
|
|
||||||
Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood.
|
Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood.
|
||||||
|
|
||||||
|
|||||||
@@ -328,7 +328,7 @@ class Agent:
|
|||||||
self,
|
self,
|
||||||
tools: Union[List[Tool], Toolbox],
|
tools: Union[List[Tool], Toolbox],
|
||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt=DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template=None,
|
tool_description_template=None,
|
||||||
additional_args={},
|
additional_args={},
|
||||||
max_iterations: int = 6,
|
max_iterations: int = 6,
|
||||||
@@ -336,6 +336,7 @@ class Agent:
|
|||||||
add_base_tools: bool = False,
|
add_base_tools: bool = False,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
memory_verbose: bool = False,
|
memory_verbose: bool = False,
|
||||||
|
grammar: Dict[str, str] = None,
|
||||||
):
|
):
|
||||||
self.agent_name = self.__class__.__name__
|
self.agent_name = self.__class__.__name__
|
||||||
self.llm_engine = llm_engine
|
self.llm_engine = llm_engine
|
||||||
@@ -347,6 +348,7 @@ class Agent:
|
|||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.tool_parser = tool_parser
|
self.tool_parser = tool_parser
|
||||||
|
self.grammar = grammar
|
||||||
|
|
||||||
if isinstance(tools, Toolbox):
|
if isinstance(tools, Toolbox):
|
||||||
self._toolbox = tools
|
self._toolbox = tools
|
||||||
@@ -533,6 +535,7 @@ class CodeAgent(Agent):
|
|||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
|
grammar: Dict[str, str] = None,
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -541,6 +544,7 @@ class CodeAgent(Agent):
|
|||||||
llm_engine=llm_engine,
|
llm_engine=llm_engine,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
tool_description_template=tool_description_template,
|
tool_description_template=tool_description_template,
|
||||||
|
grammar=grammar,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -599,7 +603,9 @@ class CodeAgent(Agent):
|
|||||||
self.prompt = [prompt_message, task_message]
|
self.prompt = [prompt_message, task_message]
|
||||||
self.logger.info("====Executing with this prompt====")
|
self.logger.info("====Executing with this prompt====")
|
||||||
self.logger.info(self.prompt)
|
self.logger.info(self.prompt)
|
||||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"])
|
|
||||||
|
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
|
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
|
||||||
|
|
||||||
if return_generated_code:
|
if return_generated_code:
|
||||||
return llm_output
|
return llm_output
|
||||||
@@ -652,6 +658,7 @@ class ReactAgent(Agent):
|
|||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
|
grammar: Dict[str, str] = None,
|
||||||
plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0],
|
plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0],
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -662,6 +669,7 @@ class ReactAgent(Agent):
|
|||||||
llm_engine=llm_engine,
|
llm_engine=llm_engine,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
tool_description_template=tool_description_template,
|
tool_description_template=tool_description_template,
|
||||||
|
grammar=grammar,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.planning_interval = planning_interval
|
self.planning_interval = planning_interval
|
||||||
@@ -881,6 +889,7 @@ class ReactJsonAgent(ReactAgent):
|
|||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
|
grammar: Dict[str, str] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -889,6 +898,7 @@ class ReactJsonAgent(ReactAgent):
|
|||||||
llm_engine=llm_engine,
|
llm_engine=llm_engine,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
tool_description_template=tool_description_template,
|
tool_description_template=tool_description_template,
|
||||||
|
grammar=grammar,
|
||||||
planning_interval=planning_interval,
|
planning_interval=planning_interval,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -912,7 +922,10 @@ class ReactJsonAgent(ReactAgent):
|
|||||||
self.logger.info(self.prompt[-1])
|
self.logger.info(self.prompt[-1])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
|
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
|
llm_output = self.llm_engine(
|
||||||
|
self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
self.logger.debug("===== Output message of the LLM: =====")
|
self.logger.debug("===== Output message of the LLM: =====")
|
||||||
@@ -982,6 +995,7 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
|
grammar: Dict[str, str] = None,
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -991,6 +1005,7 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
llm_engine=llm_engine,
|
llm_engine=llm_engine,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
tool_description_template=tool_description_template,
|
tool_description_template=tool_description_template,
|
||||||
|
grammar=grammar,
|
||||||
planning_interval=planning_interval,
|
planning_interval=planning_interval,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -1028,7 +1043,10 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
self.logger.info(self.prompt[-2:])
|
self.logger.info(self.prompt[-2:])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
|
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
|
llm_output = self.llm_engine(
|
||||||
|
self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from huggingface_hub import InferenceClient
|
from huggingface_hub import InferenceClient
|
||||||
|
|
||||||
@@ -66,16 +66,24 @@ llama_role_conversions = {
|
|||||||
|
|
||||||
|
|
||||||
class HfEngine:
|
class HfEngine:
|
||||||
def __init__(self, model: str = "meta-llama/Meta-Llama-3-8B-Instruct"):
|
def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.client = InferenceClient(model=self.model, timeout=120)
|
self.client = InferenceClient(self.model, timeout=120)
|
||||||
|
|
||||||
def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
|
def __call__(
|
||||||
|
self, messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
# Get clean message list
|
# Get clean message list
|
||||||
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
||||||
|
|
||||||
# Get LLM output
|
# Get LLM output
|
||||||
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
|
if grammar is not None:
|
||||||
|
response = self.client.chat_completion(
|
||||||
|
messages, stop=stop_sequences, max_tokens=1500, response_format=grammar
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
|
||||||
|
|
||||||
response = response.choices[0].message.content
|
response = response.choices[0].message.content
|
||||||
|
|
||||||
# Remove stop sequences from LLM output
|
# Remove stop sequences from LLM output
|
||||||
@@ -83,3 +91,14 @@ class HfEngine:
|
|||||||
if response[-len(stop_seq) :] == stop_seq:
|
if response[-len(stop_seq) :] == stop_seq:
|
||||||
response = response[: -len(stop_seq)]
|
response = response[: -len(stop_seq)]
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
|
||||||
|
"type": "regex",
|
||||||
|
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
|
||||||
|
"type": "regex",
|
||||||
|
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
|
||||||
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ Examples:
|
|||||||
---
|
---
|
||||||
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
|
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
|
||||||
|
|
||||||
I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
|
Thought: I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
||||||
@@ -75,7 +75,7 @@ final_answer(f"The answer is {answer}")
|
|||||||
---
|
---
|
||||||
Task: "Identify the oldest person in the `document` and create an image showcasing the result."
|
Task: "Identify the oldest person in the `document` and create an image showcasing the result."
|
||||||
|
|
||||||
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
Thought: I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
answer = document_qa(document, question="What is the oldest person?")
|
answer = document_qa(document, question="What is the oldest person?")
|
||||||
@@ -87,7 +87,7 @@ final_answer(image)
|
|||||||
---
|
---
|
||||||
Task: "Generate an image using the text given in the variable `caption`."
|
Task: "Generate an image using the text given in the variable `caption`."
|
||||||
|
|
||||||
I will use the following tool: `image_generator` to generate an image.
|
Thought: I will use the following tool: `image_generator` to generate an image.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
image = image_generator(prompt=caption)
|
image = image_generator(prompt=caption)
|
||||||
@@ -97,7 +97,7 @@ final_answer(image)
|
|||||||
---
|
---
|
||||||
Task: "Summarize the text given in the variable `text` and read it out loud."
|
Task: "Summarize the text given in the variable `text` and read it out loud."
|
||||||
|
|
||||||
I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
|
Thought: I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
summarized_text = summarizer(text)
|
summarized_text = summarizer(text)
|
||||||
@@ -109,7 +109,7 @@ final_answer(audio_summary)
|
|||||||
---
|
---
|
||||||
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
|
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
|
||||||
|
|
||||||
I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
|
Thought: I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
answer = text_qa(text=text, question=question)
|
answer = text_qa(text=text, question=question)
|
||||||
@@ -121,7 +121,7 @@ final_answer(image)
|
|||||||
---
|
---
|
||||||
Task: "Caption the following `image`."
|
Task: "Caption the following `image`."
|
||||||
|
|
||||||
I will use the following tool: `image_captioner` to generate a caption for the image.
|
Thought: I will use the following tool: `image_captioner` to generate a caption for the image.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
caption = image_captioner(image)
|
caption = image_captioner(image)
|
||||||
@@ -292,7 +292,6 @@ print(answer)
|
|||||||
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
||||||
|
|
||||||
Thought: I will now generate an image showcasing the oldest person.
|
Thought: I will now generate an image showcasing the oldest person.
|
||||||
|
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
|
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
|
||||||
@@ -303,7 +302,6 @@ final_answer(image)
|
|||||||
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||||
|
|
||||||
Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool
|
Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool
|
||||||
|
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
result = 5 + 3 + 1294.678
|
result = 5 + 3 + 1294.678
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def get_new_path(suffix="") -> str:
|
|||||||
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||||
|
|
||||||
|
|
||||||
def fake_react_json_llm(messages, stop_sequences=None) -> str:
|
def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
|
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
@@ -53,7 +53,7 @@ Action:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def fake_react_code_llm(messages, stop_sequences=None) -> str:
|
def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
return """
|
return """
|
||||||
@@ -119,7 +119,7 @@ final_answer(res)
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
|
def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
return """
|
return """
|
||||||
Thought: I should multiply 2 by 3.6452. special_marker
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
Code:
|
Code:
|
||||||
@@ -130,7 +130,7 @@ final_answer(result)
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def fake_code_llm_no_return(messages, stop_sequences=None) -> str:
|
def fake_code_llm_no_return(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
return """
|
return """
|
||||||
Thought: I should multiply 2 by 3.6452. special_marker
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
Code:
|
Code:
|
||||||
|
|||||||
Reference in New Issue
Block a user