Multi agents with manager (#32687)

* Add Multi agents with a hierarchical system
This commit is contained in:
Aymeric Roucher
2024-09-04 17:30:54 +02:00
committed by GitHub
parent d2dcff96f8
commit 2cb543db77
6 changed files with 192 additions and 51 deletions

View File

@@ -24,7 +24,7 @@ from ..utils import (
_import_structure = {
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"llm_engine": ["HfApiEngine", "TransformersEngine"],
"monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
@@ -45,7 +45,7 @@ else:
_import_structure["translation"] = ["TranslationTool"]
if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .llm_engine import HfApiEngine, TransformersEngine
from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool

View File

@@ -57,8 +57,11 @@ class CustomFormatter(logging.Formatter):
bold_yellow = "\x1b[33;1m"
red = "\x1b[31;20m"
green = "\x1b[32;20m"
bold_green = "\x1b[32;20;1m"
bold_red = "\x1b[31;1m"
bold_white = "\x1b[37;1m"
orange = "\x1b[38;5;214m"
bold_orange = "\x1b[38;5;214;1m"
reset = "\x1b[0m"
format = "%(message)s"
@@ -66,11 +69,14 @@ class CustomFormatter(logging.Formatter):
logging.DEBUG: grey + format + reset,
logging.INFO: format,
logging.WARNING: bold_yellow + format + reset,
31: reset + format + reset,
32: green + format + reset,
33: bold_white + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset,
31: reset + format + reset,
32: green + format + reset,
33: bold_green + format + reset,
34: bold_white + format + reset,
35: orange + format + reset,
36: bold_orange + format + reset,
}
def format(self, record):
@@ -311,12 +317,32 @@ class AgentGenerationError(AgentError):
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
if "<<tool_names>>" in prompt:
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
return prompt
def show_agents_descriptions(managed_agents: list):
managed_agents_descriptions = """
You can also give requests to team members.
Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
Given that this team member is a real human, you should be very verbose in your request.
Here is a list of the team members that you can call:"""
for agent in managed_agents.values():
managed_agents_descriptions += f"\n- {agent.name}: {agent.description}"
return managed_agents_descriptions
def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str:
if managed_agents is not None:
return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents))
else:
return prompt_template.replace("<<managed_agents_descriptions>>", "")
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
if "<<authorized_imports>>" not in prompt_template:
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
@@ -335,8 +361,8 @@ class Agent:
tool_parser=parse_json_tool_call,
add_base_tools: bool = False,
verbose: int = 0,
memory_verbose: bool = False,
grammar: Dict[str, str] = None,
managed_agents: List = None,
):
self.agent_name = self.__class__.__name__
self.llm_engine = llm_engine
@@ -350,6 +376,10 @@ class Agent:
self.tool_parser = tool_parser
self.grammar = grammar
self.managed_agents = None
if managed_agents is not None:
self.managed_agents = {agent.name: agent for agent in managed_agents}
if isinstance(tools, Toolbox):
self._toolbox = tools
if add_base_tools:
@@ -364,10 +394,10 @@ class Agent:
self.system_prompt = format_prompt_with_tools(
self._toolbox, self.system_prompt_template, self.tool_description_template
)
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
self.prompt = None
self.logs = []
self.task = None
self.memory_verbose = memory_verbose
if verbose == 0:
logger.setLevel(logging.WARNING)
@@ -388,13 +418,14 @@ class Agent:
self.system_prompt_template,
self.tool_description_template,
)
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
if hasattr(self, "authorized_imports"):
self.system_prompt = format_prompt_with_imports(
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
)
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
self.logger.warn("======== New task ========")
self.logger.log(33, self.task)
self.logger.log(33, "======== New task ========")
self.logger.log(34, self.task)
self.logger.debug("System prompt is as follows:")
self.logger.debug(self.system_prompt)
@@ -444,12 +475,12 @@ class Agent:
if "error" in step_log or "observation" in step_log:
if "error" in step_log:
message_content = (
f"[OUTPUT OF STEP {i}] Error: "
f"[OUTPUT OF STEP {i}] -> Error:\n"
+ str(step_log["error"])
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
elif "observation" in step_log:
message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}"
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log['observation']}"
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
memory.append(tool_response_message)
@@ -477,7 +508,7 @@ class Agent:
raise AgentParsingError(
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
)
return rationale, action
return rationale.strip(), action.strip()
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
"""
@@ -488,29 +519,44 @@ class Agent:
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
if tool_name not in self.toolbox.tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(self.toolbox.tools.keys())}."
available_tools = self.toolbox.tools
if self.managed_agents is not None:
available_tools = {**available_tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
self.logger.error(error_msg, exc_info=1)
raise AgentExecutionError(error_msg)
try:
if isinstance(arguments, str):
observation = self.toolbox.tools[tool_name](arguments)
else:
observation = available_tools[tool_name](arguments)
elif isinstance(arguments, dict):
for key, value in arguments.items():
# if the value is the name of a state variable like "image.png", replace it with the actual value
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
observation = self.toolbox.tools[tool_name](**arguments)
observation = available_tools[tool_name](**arguments)
else:
raise AgentExecutionError(
f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
)
return observation
except Exception as e:
if tool_name in self.toolbox.tools:
raise AgentExecutionError(
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(self.toolbox.tools[tool_name])}"
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}"
)
elif tool_name in self.managed_agents:
raise AgentExecutionError(
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
)
def log_code_action(self, code_action: str) -> None:
self.logger.warning("==== Agent is executing the code below:")
def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
self.logger.warning("=== Agent thoughts:")
self.logger.log(31, rationale)
self.logger.warning(">>> Agent is executing the code below:")
if is_pygments_available():
self.logger.log(
31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
@@ -612,12 +658,12 @@ class CodeAgent(Agent):
# Parse
try:
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
)
code_action = llm_output
rationale, code_action = "", llm_output
try:
code_action = self.parse_code_blob(code_action)
@@ -627,7 +673,7 @@ class CodeAgent(Agent):
return error_msg
# Execute
self.log_code_action(code_action)
self.log_rationale_code_action(rationale, code_action)
try:
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
output = self.python_evaluator(
@@ -813,6 +859,9 @@ Now begin!""",
"content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["user"].format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
),
answer_facts=answer_facts,
),
}
@@ -829,8 +878,8 @@ Now begin!""",
{answer_facts}
```""".strip()
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.debug("===== Initial plan: =====")
self.logger.debug(final_plan_redaction)
self.logger.log(36, "===== Initial plan =====")
self.logger.log(35, final_plan_redaction)
else: # update plan
agent_memory = self.write_inner_memory_from_logs(
summary_mode=False
@@ -857,6 +906,9 @@ Now begin!""",
"content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["user"].format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
),
facts_update=facts_update,
remaining_steps=(self.max_iterations - iteration),
),
@@ -872,8 +924,8 @@ Now begin!""",
{facts_update}
```"""
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.debug("===== Updated plan: =====")
self.logger.debug(final_plan_redaction)
self.logger.log(36, "===== Updated plan =====")
self.logger.log(35, final_plan_redaction)
class ReactJsonAgent(ReactAgent):
@@ -945,7 +997,9 @@ class ReactJsonAgent(ReactAgent):
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
# Execute
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
self.logger.warning("=== Agent thoughts:")
self.logger.log(31, rationale)
self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
if tool_name == "final_answer":
if isinstance(arguments, dict):
if "answer" in arguments:
@@ -961,6 +1015,8 @@ class ReactJsonAgent(ReactAgent):
current_step_logs["final_answer"] = answer
return current_step_logs
else:
if arguments is None:
arguments = {}
observation = self.execute_tool_call(tool_name, arguments)
observation_type = type(observation)
if observation_type == AgentText:
@@ -1050,12 +1106,12 @@ class ReactCodeAgent(ReactAgent):
except Exception as e:
raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("===== Output message of the LLM: =====")
self.logger.debug("=== Output message of the LLM:")
self.logger.debug(llm_output)
current_step_logs["llm_output"] = llm_output
# Parse
self.logger.debug("===== Extracting action =====")
self.logger.debug("=== Extracting action ===")
try:
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
@@ -1072,22 +1128,30 @@ class ReactCodeAgent(ReactAgent):
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
# Execute
self.log_code_action(code_action)
self.log_rationale_code_action(rationale, code_action)
try:
result = self.python_evaluator(
code_action,
static_tools={
static_tools = {
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
},
}
if self.managed_agents is not None:
static_tools = {**static_tools, **self.managed_agents}
result = self.python_evaluator(
code_action,
static_tools=static_tools,
custom_tools=self.custom_tools,
state=self.state,
authorized_imports=self.authorized_imports,
)
information = self.state["print_outputs"]
self.logger.warning("Print outputs:")
self.logger.log(32, information)
current_step_logs["observation"] = information
self.logger.log(32, self.state["print_outputs"])
if result is not None:
self.logger.warning("Last output from code snippet:")
self.logger.log(32, str(result))
observation = "Print outputs:\n" + self.state["print_outputs"]
if result is not None:
observation += "Last output from code snippet:\n" + str(result)[:100000]
current_step_logs["observation"] = observation
except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e):
@@ -1095,7 +1159,57 @@ class ReactCodeAgent(ReactAgent):
raise AgentExecutionError(error_msg)
for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer":
self.logger.warning(">>> Final answer:")
self.logger.log(33, "Final answer:")
self.logger.log(32, result)
current_step_logs["final_answer"] = result
return current_step_logs
class ManagedAgent:
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
self.agent = agent
self.name = name
self.description = description
self.additional_prompting = additional_prompting
self.provide_run_summary = provide_run_summary
def write_full_task(self, task):
full_task = f"""You're a helpful agent named '{self.name}'.
You have been submitted this task by your manager.
---
Task:
{task}
---
You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer.
Your final_answer WILL HAVE to contain these parts:
### 1. Task outcome (short version):
### 2. Task outcome (extremely detailed version):
### 3. Additional context (if relevant):
Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
<<additional_prompting>>"""
if self.additional_prompting:
full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip()
else:
full_task = full_task.replace("\n<<additional_prompting>>", "").strip()
return full_task
def __call__(self, request, **kwargs):
full_task = self.write_full_task(request)
output = self.agent.run(full_task, **kwargs)
if self.provide_run_summary:
answer = f"Here is the final answer from your managed agent '{self.name}':\n"
answer += str(output)
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
content = message["content"]
if len(str(content)) < 1000 or "[FACTS LIST]" in str(content):
answer += "\n" + str(content) + "\n---"
else:
answer += "\n" + str(content)[:1000] + "\n(...Step was truncated because too long)...\n---"
answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
return answer
else:
return output

View File

@@ -29,7 +29,7 @@ from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
def custom_print(*args):
return " ".join(map(str, args))
return None
BASE_PYTHON_TOOLS = {

View File

@@ -332,10 +332,10 @@ final_answer("Shanghai")
---
Task: "What is the current age of the pope, raised to the power 0.36?"
Thought: I will use the tool `search` to get the age of the pope, then raise it to the power 0.36.
Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
Code:
```py
pope_age = search(query="current pope age")
pope_age = wiki(query="current pope age")
print("Pope age:", pope_age)
```<end_action>
Observation:
@@ -348,16 +348,16 @@ pope_current_age = 85 ** 0.36
final_answer(pope_current_age)
```<end_action>
Above example were using notional tools that might not exist for you. You only have acces to those tools:
Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have acces to those tools (and no other tool):
<<tool_descriptions>>
You also can perform computations in the Python code that you generate.
<<managed_agents_descriptions>>
Here are the rules you should always follow to solve your task:
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
2. Use only variables that you have defined!
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'.
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
@@ -410,6 +410,8 @@ Task:
Your plan can leverage any of these tools:
{tool_descriptions}
{managed_agents_descriptions}
List of facts that you know:
```
{answer_facts}
@@ -453,9 +455,11 @@ USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
{task}
```
You have access to these tools:
You have access to these tools and only these:
{tool_descriptions}
{managed_agents_descriptions}
Here is the up to date list of facts that you know:
```
{facts_update}

View File

@@ -434,7 +434,7 @@ def evaluate_call(call, state, static_tools, custom_tools):
global PRINT_OUTPUTS
PRINT_OUTPUTS += output + "\n"
# cap the number of lines
return output
return None
else: # Assume it's a callable object
output = func(*args, **kwargs)
return output

View File

@@ -20,7 +20,14 @@ import uuid
import pytest
from transformers.agents.agent_types import AgentText
from transformers.agents.agents import AgentMaxIterationsError, CodeAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from transformers.agents.agents import (
AgentMaxIterationsError,
CodeAgent,
ManagedAgent,
ReactCodeAgent,
ReactJsonAgent,
Toolbox,
)
from transformers.agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import require_torch
@@ -235,3 +242,19 @@ Action:
)
res = agent.run("ok")
assert res[0] == 0.5
def test_init_managed_agent(self):
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
assert managed_agent.name == "managed_agent"
assert managed_agent.description == "Empty"
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
manager_agent = ReactCodeAgent(
tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent]
)
assert "You can also give requests to team members." not in agent.system_prompt
assert "<<managed_agents_descriptions>>" not in agent.system_prompt
assert "You can also give requests to team members." in manager_agent.system_prompt