diff --git a/docs/source/en/agents.md b/docs/source/en/agents.md index f335cb678f..67c4b8a91b 100644 --- a/docs/source/en/agents.md +++ b/docs/source/en/agents.md @@ -119,10 +119,12 @@ def llm_engine(messages, stop_sequences=["Task"]) -> str: ``` You could use any `llm_engine` method as long as: -1. it follows the [messages format](./chat_templating.md) for its input (`List[Dict[str, str]]`) and returns a `str` -2. it stops generating outputs at the sequences passed in the argument `stop` +1. it follows the [messages format](./chat_templating.md) (`List[Dict[str, str]]`) for its input `messages`, and it returns a `str`. +2. it stops generating outputs at the sequences passed in the argument `stop_sequences`. -You also need a `tools` argument which accepts a list of `Tools`. You can provide an empty list for `tools`, but use the default toolbox with the optional argument `add_base_tools=True`. +Additionally, `llm_engine` can take a `grammar` argument. If you specify a `grammar` upon agent initialization, it will be passed to every call to `llm_engine`, allowing [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance) to force properly-formatted agent outputs. + +You will also need a `tools` argument which accepts a list of `Tools` - it can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`. Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood. 
diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index fdcd23e8fc..2f2316817b 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -328,7 +328,7 @@ class Agent: self, tools: Union[List[Tool], Toolbox], llm_engine: Callable = HfEngine(), - system_prompt=DEFAULT_REACT_JSON_SYSTEM_PROMPT, + system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT, tool_description_template=None, additional_args={}, max_iterations: int = 6, @@ -336,6 +336,7 @@ class Agent: add_base_tools: bool = False, verbose: int = 0, memory_verbose: bool = False, + grammar: Dict[str, str] = None, ): self.agent_name = self.__class__.__name__ self.llm_engine = llm_engine @@ -347,6 +348,7 @@ class Agent: self.max_iterations = max_iterations self.logger = logger self.tool_parser = tool_parser + self.grammar = grammar if isinstance(tools, Toolbox): self._toolbox = tools @@ -533,6 +535,7 @@ class CodeAgent(Agent): llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, + grammar: Dict[str, str] = None, additional_authorized_imports: Optional[List[str]] = None, **kwargs, ): @@ -541,6 +544,7 @@ class CodeAgent(Agent): llm_engine=llm_engine, system_prompt=system_prompt, tool_description_template=tool_description_template, + grammar=grammar, **kwargs, ) @@ -599,7 +603,9 @@ class CodeAgent(Agent): self.prompt = [prompt_message, task_message] self.logger.info("====Executing with this prompt====") self.logger.info(self.prompt) - llm_output = self.llm_engine(self.prompt, stop_sequences=[""]) + + additional_args = {"grammar": self.grammar} if self.grammar is not None else {} + llm_output = self.llm_engine(self.prompt, stop_sequences=[""], **additional_args) if return_generated_code: return llm_output @@ -652,6 +658,7 @@ class ReactAgent(Agent): llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, tool_description_template: 
str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, + grammar: Dict[str, str] = None, plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0], planning_interval: Optional[int] = None, **kwargs, @@ -662,6 +669,7 @@ class ReactAgent(Agent): llm_engine=llm_engine, system_prompt=system_prompt, tool_description_template=tool_description_template, + grammar=grammar, **kwargs, ) self.planning_interval = planning_interval @@ -881,6 +889,7 @@ class ReactJsonAgent(ReactAgent): llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, + grammar: Dict[str, str] = None, planning_interval: Optional[int] = None, **kwargs, ): @@ -889,6 +898,7 @@ class ReactJsonAgent(ReactAgent): llm_engine=llm_engine, system_prompt=system_prompt, tool_description_template=tool_description_template, + grammar=grammar, planning_interval=planning_interval, **kwargs, ) @@ -912,7 +922,10 @@ class ReactJsonAgent(ReactAgent): self.logger.info(self.prompt[-1]) try: - llm_output = self.llm_engine(self.prompt, stop_sequences=["", "Observation:"]) + additional_args = {"grammar": self.grammar} if self.grammar is not None else {} + llm_output = self.llm_engine( + self.prompt, stop_sequences=["", "Observation:"], **additional_args + ) except Exception as e: raise AgentGenerationError(f"Error in generating llm output: {e}.") self.logger.debug("===== Output message of the LLM: =====") @@ -982,6 +995,7 @@ class ReactCodeAgent(ReactAgent): llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, + grammar: Dict[str, str] = None, additional_authorized_imports: Optional[List[str]] = None, planning_interval: Optional[int] = None, **kwargs, @@ -991,6 +1005,7 @@ class ReactCodeAgent(ReactAgent): llm_engine=llm_engine, system_prompt=system_prompt, tool_description_template=tool_description_template, + 
grammar=grammar, planning_interval=planning_interval, **kwargs, ) @@ -1028,7 +1043,10 @@ class ReactCodeAgent(ReactAgent): self.logger.info(self.prompt[-2:]) try: - llm_output = self.llm_engine(self.prompt, stop_sequences=["", "Observation:"]) + additional_args = {"grammar": self.grammar} if self.grammar is not None else {} + llm_output = self.llm_engine( + self.prompt, stop_sequences=["", "Observation:"], **additional_args + ) except Exception as e: raise AgentGenerationError(f"Error in generating llm output: {e}.") diff --git a/src/transformers/agents/llm_engine.py b/src/transformers/agents/llm_engine.py index eb5edf7515..09d6176b1e 100644 --- a/src/transformers/agents/llm_engine.py +++ b/src/transformers/agents/llm_engine.py @@ -16,7 +16,7 @@ # limitations under the License. from copy import deepcopy from enum import Enum -from typing import Dict, List +from typing import Dict, List, Optional from huggingface_hub import InferenceClient @@ -66,16 +66,24 @@ llama_role_conversions = { class HfEngine: - def __init__(self, model: str = "meta-llama/Meta-Llama-3-8B-Instruct"): + def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"): self.model = model - self.client = InferenceClient(model=self.model, timeout=120) + self.client = InferenceClient(self.model, timeout=120) - def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str: + def __call__( + self, messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[str] = None + ) -> str: # Get clean message list messages = get_clean_message_list(messages, role_conversions=llama_role_conversions) # Get LLM output - response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500) + if grammar is not None: + response = self.client.chat_completion( + messages, stop=stop_sequences, max_tokens=1500, response_format=grammar + ) + else: + response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500) + response = 
response.choices[0].message.content # Remove stop sequences from LLM output @@ -83,3 +91,14 @@ class HfEngine: if response[-len(stop_seq) :] == stop_seq: response = response[: -len(stop_seq)] return response + + +DEFAULT_JSONAGENT_REGEX_GRAMMAR = { + "type": "regex", + "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n', +} + +DEFAULT_CODEAGENT_REGEX_GRAMMAR = { + "type": "regex", + "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```", +} diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index c94bb2af46..bbc674adc2 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -63,7 +63,7 @@ Examples: --- Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French." -I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image. +Thought: I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image. Code: ```py translated_question = translator(question=question, src_lang="French", tgt_lang="English") @@ -75,7 +75,7 @@ final_answer(f"The answer is {answer}") --- Task: "Identify the oldest person in the `document` and create an image showcasing the result." -I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer. +Thought: I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer. Code: ```py answer = document_qa(document, question="What is the oldest person?") @@ -87,7 +87,7 @@ final_answer(image) --- Task: "Generate an image using the text given in the variable `caption`." 
-I will use the following tool: `image_generator` to generate an image. +Thought: I will use the following tool: `image_generator` to generate an image. Code: ```py image = image_generator(prompt=caption) @@ -97,7 +97,7 @@ final_answer(image) --- Task: "Summarize the text given in the variable `text` and read it out loud." -I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud. +Thought: I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud. Code: ```py summarized_text = summarizer(text) @@ -109,7 +109,7 @@ final_answer(audio_summary) --- Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image." -I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer. +Thought: I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer. Code: ```py answer = text_qa(text=text, question=question) @@ -121,7 +121,7 @@ final_answer(image) --- Task: "Caption the following `image`." -I will use the following tool: `image_captioner` to generate a caption for the image. +Thought: I will use the following tool: `image_captioner` to generate a caption for the image. Code: ```py caption = image_captioner(image) @@ -292,7 +292,6 @@ print(answer) Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." Thought: I will now generate an image showcasing the oldest person. - Code: ```py image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.") @@ -303,7 +302,6 @@ final_answer(image) Task: "What is the result of the following operation: 5 + 3 + 1294.678?" 
Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool - Code: ```py result = 5 + 3 + 1294.678 diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 6dac8b8520..c18a568fdf 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -30,7 +30,7 @@ def get_new_path(suffix="") -> str: return os.path.join(directory, str(uuid.uuid4()) + suffix) -def fake_react_json_llm(messages, stop_sequences=None) -> str: +def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str: prompt = str(messages) if "special_marker" not in prompt: @@ -53,7 +53,7 @@ Action: """ -def fake_react_code_llm(messages, stop_sequences=None) -> str: +def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str: prompt = str(messages) if "special_marker" not in prompt: return """ @@ -119,7 +119,7 @@ final_answer(res) """ -def fake_code_llm_oneshot(messages, stop_sequences=None) -> str: +def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str: return """ Thought: I should multiply 2 by 3.6452. special_marker Code: @@ -130,7 +130,7 @@ final_answer(result) """ -def fake_code_llm_no_return(messages, stop_sequences=None) -> str: +def fake_code_llm_no_return(messages, stop_sequences=None, grammar=None) -> str: return """ Thought: I should multiply 2 by 3.6452. special_marker Code: