Multi agents with manager (#32687)
* Add Multi agents with a hierarchical system
This commit is contained in:
@@ -24,7 +24,7 @@ from ..utils import (
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
|
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
|
||||||
"llm_engine": ["HfApiEngine", "TransformersEngine"],
|
"llm_engine": ["HfApiEngine", "TransformersEngine"],
|
||||||
"monitoring": ["stream_to_gradio"],
|
"monitoring": ["stream_to_gradio"],
|
||||||
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
|
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
|
||||||
@@ -45,7 +45,7 @@ else:
|
|||||||
_import_structure["translation"] = ["TranslationTool"]
|
_import_structure["translation"] = ["TranslationTool"]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||||
from .llm_engine import HfApiEngine, TransformersEngine
|
from .llm_engine import HfApiEngine, TransformersEngine
|
||||||
from .monitoring import stream_to_gradio
|
from .monitoring import stream_to_gradio
|
||||||
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
|
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
|
||||||
|
|||||||
@@ -57,8 +57,11 @@ class CustomFormatter(logging.Formatter):
|
|||||||
bold_yellow = "\x1b[33;1m"
|
bold_yellow = "\x1b[33;1m"
|
||||||
red = "\x1b[31;20m"
|
red = "\x1b[31;20m"
|
||||||
green = "\x1b[32;20m"
|
green = "\x1b[32;20m"
|
||||||
|
bold_green = "\x1b[32;20;1m"
|
||||||
bold_red = "\x1b[31;1m"
|
bold_red = "\x1b[31;1m"
|
||||||
bold_white = "\x1b[37;1m"
|
bold_white = "\x1b[37;1m"
|
||||||
|
orange = "\x1b[38;5;214m"
|
||||||
|
bold_orange = "\x1b[38;5;214;1m"
|
||||||
reset = "\x1b[0m"
|
reset = "\x1b[0m"
|
||||||
format = "%(message)s"
|
format = "%(message)s"
|
||||||
|
|
||||||
@@ -66,11 +69,14 @@ class CustomFormatter(logging.Formatter):
|
|||||||
logging.DEBUG: grey + format + reset,
|
logging.DEBUG: grey + format + reset,
|
||||||
logging.INFO: format,
|
logging.INFO: format,
|
||||||
logging.WARNING: bold_yellow + format + reset,
|
logging.WARNING: bold_yellow + format + reset,
|
||||||
31: reset + format + reset,
|
|
||||||
32: green + format + reset,
|
|
||||||
33: bold_white + format + reset,
|
|
||||||
logging.ERROR: red + format + reset,
|
logging.ERROR: red + format + reset,
|
||||||
logging.CRITICAL: bold_red + format + reset,
|
logging.CRITICAL: bold_red + format + reset,
|
||||||
|
31: reset + format + reset,
|
||||||
|
32: green + format + reset,
|
||||||
|
33: bold_green + format + reset,
|
||||||
|
34: bold_white + format + reset,
|
||||||
|
35: orange + format + reset,
|
||||||
|
36: bold_orange + format + reset,
|
||||||
}
|
}
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
@@ -311,12 +317,32 @@ class AgentGenerationError(AgentError):
|
|||||||
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
|
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
|
||||||
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
||||||
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
|
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
|
||||||
|
|
||||||
if "<<tool_names>>" in prompt:
|
if "<<tool_names>>" in prompt:
|
||||||
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
|
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
|
||||||
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
|
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def show_agents_descriptions(managed_agents: list):
|
||||||
|
managed_agents_descriptions = """
|
||||||
|
You can also give requests to team members.
|
||||||
|
Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
|
||||||
|
Given that this team member is a real human, you should be very verbose in your request.
|
||||||
|
Here is a list of the team members that you can call:"""
|
||||||
|
for agent in managed_agents.values():
|
||||||
|
managed_agents_descriptions += f"\n- {agent.name}: {agent.description}"
|
||||||
|
return managed_agents_descriptions
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str:
|
||||||
|
if managed_agents is not None:
|
||||||
|
return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents))
|
||||||
|
else:
|
||||||
|
return prompt_template.replace("<<managed_agents_descriptions>>", "")
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
|
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
|
||||||
if "<<authorized_imports>>" not in prompt_template:
|
if "<<authorized_imports>>" not in prompt_template:
|
||||||
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
|
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
|
||||||
@@ -335,8 +361,8 @@ class Agent:
|
|||||||
tool_parser=parse_json_tool_call,
|
tool_parser=parse_json_tool_call,
|
||||||
add_base_tools: bool = False,
|
add_base_tools: bool = False,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
memory_verbose: bool = False,
|
|
||||||
grammar: Dict[str, str] = None,
|
grammar: Dict[str, str] = None,
|
||||||
|
managed_agents: List = None,
|
||||||
):
|
):
|
||||||
self.agent_name = self.__class__.__name__
|
self.agent_name = self.__class__.__name__
|
||||||
self.llm_engine = llm_engine
|
self.llm_engine = llm_engine
|
||||||
@@ -350,6 +376,10 @@ class Agent:
|
|||||||
self.tool_parser = tool_parser
|
self.tool_parser = tool_parser
|
||||||
self.grammar = grammar
|
self.grammar = grammar
|
||||||
|
|
||||||
|
self.managed_agents = None
|
||||||
|
if managed_agents is not None:
|
||||||
|
self.managed_agents = {agent.name: agent for agent in managed_agents}
|
||||||
|
|
||||||
if isinstance(tools, Toolbox):
|
if isinstance(tools, Toolbox):
|
||||||
self._toolbox = tools
|
self._toolbox = tools
|
||||||
if add_base_tools:
|
if add_base_tools:
|
||||||
@@ -364,10 +394,10 @@ class Agent:
|
|||||||
self.system_prompt = format_prompt_with_tools(
|
self.system_prompt = format_prompt_with_tools(
|
||||||
self._toolbox, self.system_prompt_template, self.tool_description_template
|
self._toolbox, self.system_prompt_template, self.tool_description_template
|
||||||
)
|
)
|
||||||
|
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
|
||||||
self.prompt = None
|
self.prompt = None
|
||||||
self.logs = []
|
self.logs = []
|
||||||
self.task = None
|
self.task = None
|
||||||
self.memory_verbose = memory_verbose
|
|
||||||
|
|
||||||
if verbose == 0:
|
if verbose == 0:
|
||||||
logger.setLevel(logging.WARNING)
|
logger.setLevel(logging.WARNING)
|
||||||
@@ -388,13 +418,14 @@ class Agent:
|
|||||||
self.system_prompt_template,
|
self.system_prompt_template,
|
||||||
self.tool_description_template,
|
self.tool_description_template,
|
||||||
)
|
)
|
||||||
|
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
|
||||||
if hasattr(self, "authorized_imports"):
|
if hasattr(self, "authorized_imports"):
|
||||||
self.system_prompt = format_prompt_with_imports(
|
self.system_prompt = format_prompt_with_imports(
|
||||||
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
|
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
|
||||||
)
|
)
|
||||||
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
|
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
|
||||||
self.logger.warn("======== New task ========")
|
self.logger.log(33, "======== New task ========")
|
||||||
self.logger.log(33, self.task)
|
self.logger.log(34, self.task)
|
||||||
self.logger.debug("System prompt is as follows:")
|
self.logger.debug("System prompt is as follows:")
|
||||||
self.logger.debug(self.system_prompt)
|
self.logger.debug(self.system_prompt)
|
||||||
|
|
||||||
@@ -444,12 +475,12 @@ class Agent:
|
|||||||
if "error" in step_log or "observation" in step_log:
|
if "error" in step_log or "observation" in step_log:
|
||||||
if "error" in step_log:
|
if "error" in step_log:
|
||||||
message_content = (
|
message_content = (
|
||||||
f"[OUTPUT OF STEP {i}] Error: "
|
f"[OUTPUT OF STEP {i}] -> Error:\n"
|
||||||
+ str(step_log["error"])
|
+ str(step_log["error"])
|
||||||
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
||||||
)
|
)
|
||||||
elif "observation" in step_log:
|
elif "observation" in step_log:
|
||||||
message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}"
|
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log['observation']}"
|
||||||
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
||||||
memory.append(tool_response_message)
|
memory.append(tool_response_message)
|
||||||
|
|
||||||
@@ -477,7 +508,7 @@ class Agent:
|
|||||||
raise AgentParsingError(
|
raise AgentParsingError(
|
||||||
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
||||||
)
|
)
|
||||||
return rationale, action
|
return rationale.strip(), action.strip()
|
||||||
|
|
||||||
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
||||||
"""
|
"""
|
||||||
@@ -488,29 +519,44 @@ class Agent:
|
|||||||
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
||||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||||
"""
|
"""
|
||||||
if tool_name not in self.toolbox.tools:
|
available_tools = self.toolbox.tools
|
||||||
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(self.toolbox.tools.keys())}."
|
if self.managed_agents is not None:
|
||||||
|
available_tools = {**available_tools, **self.managed_agents}
|
||||||
|
if tool_name not in available_tools:
|
||||||
|
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||||
self.logger.error(error_msg, exc_info=1)
|
self.logger.error(error_msg, exc_info=1)
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(arguments, str):
|
if isinstance(arguments, str):
|
||||||
observation = self.toolbox.tools[tool_name](arguments)
|
observation = available_tools[tool_name](arguments)
|
||||||
else:
|
elif isinstance(arguments, dict):
|
||||||
for key, value in arguments.items():
|
for key, value in arguments.items():
|
||||||
# if the value is the name of a state variable like "image.png", replace it with the actual value
|
# if the value is the name of a state variable like "image.png", replace it with the actual value
|
||||||
if isinstance(value, str) and value in self.state:
|
if isinstance(value, str) and value in self.state:
|
||||||
arguments[key] = self.state[value]
|
arguments[key] = self.state[value]
|
||||||
observation = self.toolbox.tools[tool_name](**arguments)
|
observation = available_tools[tool_name](**arguments)
|
||||||
|
else:
|
||||||
|
raise AgentExecutionError(
|
||||||
|
f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||||
|
)
|
||||||
return observation
|
return observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if tool_name in self.toolbox.tools:
|
||||||
raise AgentExecutionError(
|
raise AgentExecutionError(
|
||||||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||||
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(self.toolbox.tools[tool_name])}"
|
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}"
|
||||||
|
)
|
||||||
|
elif tool_name in self.managed_agents:
|
||||||
|
raise AgentExecutionError(
|
||||||
|
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
||||||
|
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_code_action(self, code_action: str) -> None:
|
def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
|
||||||
self.logger.warning("==== Agent is executing the code below:")
|
self.logger.warning("=== Agent thoughts:")
|
||||||
|
self.logger.log(31, rationale)
|
||||||
|
self.logger.warning(">>> Agent is executing the code below:")
|
||||||
if is_pygments_available():
|
if is_pygments_available():
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
|
31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
|
||||||
@@ -612,12 +658,12 @@ class CodeAgent(Agent):
|
|||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
try:
|
try:
|
||||||
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
|
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
|
||||||
)
|
)
|
||||||
code_action = llm_output
|
rationale, code_action = "", llm_output
|
||||||
|
|
||||||
try:
|
try:
|
||||||
code_action = self.parse_code_blob(code_action)
|
code_action = self.parse_code_blob(code_action)
|
||||||
@@ -627,7 +673,7 @@ class CodeAgent(Agent):
|
|||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.log_code_action(code_action)
|
self.log_rationale_code_action(rationale, code_action)
|
||||||
try:
|
try:
|
||||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||||
output = self.python_evaluator(
|
output = self.python_evaluator(
|
||||||
@@ -813,6 +859,9 @@ Now begin!""",
|
|||||||
"content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["user"].format(
|
"content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["user"].format(
|
||||||
task=task,
|
task=task,
|
||||||
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
||||||
|
managed_agents_descriptions=(
|
||||||
|
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
|
||||||
|
),
|
||||||
answer_facts=answer_facts,
|
answer_facts=answer_facts,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@@ -829,8 +878,8 @@ Now begin!""",
|
|||||||
{answer_facts}
|
{answer_facts}
|
||||||
```""".strip()
|
```""".strip()
|
||||||
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
||||||
self.logger.debug("===== Initial plan: =====")
|
self.logger.log(36, "===== Initial plan =====")
|
||||||
self.logger.debug(final_plan_redaction)
|
self.logger.log(35, final_plan_redaction)
|
||||||
else: # update plan
|
else: # update plan
|
||||||
agent_memory = self.write_inner_memory_from_logs(
|
agent_memory = self.write_inner_memory_from_logs(
|
||||||
summary_mode=False
|
summary_mode=False
|
||||||
@@ -857,6 +906,9 @@ Now begin!""",
|
|||||||
"content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["user"].format(
|
"content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["user"].format(
|
||||||
task=task,
|
task=task,
|
||||||
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
||||||
|
managed_agents_descriptions=(
|
||||||
|
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
|
||||||
|
),
|
||||||
facts_update=facts_update,
|
facts_update=facts_update,
|
||||||
remaining_steps=(self.max_iterations - iteration),
|
remaining_steps=(self.max_iterations - iteration),
|
||||||
),
|
),
|
||||||
@@ -872,8 +924,8 @@ Now begin!""",
|
|||||||
{facts_update}
|
{facts_update}
|
||||||
```"""
|
```"""
|
||||||
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
||||||
self.logger.debug("===== Updated plan: =====")
|
self.logger.log(36, "===== Updated plan =====")
|
||||||
self.logger.debug(final_plan_redaction)
|
self.logger.log(35, final_plan_redaction)
|
||||||
|
|
||||||
|
|
||||||
class ReactJsonAgent(ReactAgent):
|
class ReactJsonAgent(ReactAgent):
|
||||||
@@ -945,7 +997,9 @@ class ReactJsonAgent(ReactAgent):
|
|||||||
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
|
self.logger.warning("=== Agent thoughts:")
|
||||||
|
self.logger.log(31, rationale)
|
||||||
|
self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
|
||||||
if tool_name == "final_answer":
|
if tool_name == "final_answer":
|
||||||
if isinstance(arguments, dict):
|
if isinstance(arguments, dict):
|
||||||
if "answer" in arguments:
|
if "answer" in arguments:
|
||||||
@@ -961,6 +1015,8 @@ class ReactJsonAgent(ReactAgent):
|
|||||||
current_step_logs["final_answer"] = answer
|
current_step_logs["final_answer"] = answer
|
||||||
return current_step_logs
|
return current_step_logs
|
||||||
else:
|
else:
|
||||||
|
if arguments is None:
|
||||||
|
arguments = {}
|
||||||
observation = self.execute_tool_call(tool_name, arguments)
|
observation = self.execute_tool_call(tool_name, arguments)
|
||||||
observation_type = type(observation)
|
observation_type = type(observation)
|
||||||
if observation_type == AgentText:
|
if observation_type == AgentText:
|
||||||
@@ -1050,12 +1106,12 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
|
|
||||||
self.logger.debug("===== Output message of the LLM: =====")
|
self.logger.debug("=== Output message of the LLM:")
|
||||||
self.logger.debug(llm_output)
|
self.logger.debug(llm_output)
|
||||||
current_step_logs["llm_output"] = llm_output
|
current_step_logs["llm_output"] = llm_output
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
self.logger.debug("===== Extracting action =====")
|
self.logger.debug("=== Extracting action ===")
|
||||||
try:
|
try:
|
||||||
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1072,22 +1128,30 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.log_code_action(code_action)
|
self.log_rationale_code_action(rationale, code_action)
|
||||||
try:
|
try:
|
||||||
result = self.python_evaluator(
|
|
||||||
code_action,
|
|
||||||
static_tools = {
|
static_tools = {
|
||||||
**BASE_PYTHON_TOOLS.copy(),
|
**BASE_PYTHON_TOOLS.copy(),
|
||||||
**self.toolbox.tools,
|
**self.toolbox.tools,
|
||||||
},
|
}
|
||||||
|
if self.managed_agents is not None:
|
||||||
|
static_tools = {**static_tools, **self.managed_agents}
|
||||||
|
result = self.python_evaluator(
|
||||||
|
code_action,
|
||||||
|
static_tools=static_tools,
|
||||||
custom_tools=self.custom_tools,
|
custom_tools=self.custom_tools,
|
||||||
state=self.state,
|
state=self.state,
|
||||||
authorized_imports=self.authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)
|
||||||
information = self.state["print_outputs"]
|
|
||||||
self.logger.warning("Print outputs:")
|
self.logger.warning("Print outputs:")
|
||||||
self.logger.log(32, information)
|
self.logger.log(32, self.state["print_outputs"])
|
||||||
current_step_logs["observation"] = information
|
if result is not None:
|
||||||
|
self.logger.warning("Last output from code snippet:")
|
||||||
|
self.logger.log(32, str(result))
|
||||||
|
observation = "Print outputs:\n" + self.state["print_outputs"]
|
||||||
|
if result is not None:
|
||||||
|
observation += "Last output from code snippet:\n" + str(result)[:100000]
|
||||||
|
current_step_logs["observation"] = observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
||||||
if "'dict' object has no attribute 'read'" in str(e):
|
if "'dict' object has no attribute 'read'" in str(e):
|
||||||
@@ -1095,7 +1159,57 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
for line in code_action.split("\n"):
|
for line in code_action.split("\n"):
|
||||||
if line[: len("final_answer")] == "final_answer":
|
if line[: len("final_answer")] == "final_answer":
|
||||||
self.logger.warning(">>> Final answer:")
|
self.logger.log(33, "Final answer:")
|
||||||
self.logger.log(32, result)
|
self.logger.log(32, result)
|
||||||
current_step_logs["final_answer"] = result
|
current_step_logs["final_answer"] = result
|
||||||
return current_step_logs
|
return current_step_logs
|
||||||
|
|
||||||
|
|
||||||
|
class ManagedAgent:
|
||||||
|
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
|
||||||
|
self.agent = agent
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.additional_prompting = additional_prompting
|
||||||
|
self.provide_run_summary = provide_run_summary
|
||||||
|
|
||||||
|
def write_full_task(self, task):
|
||||||
|
full_task = f"""You're a helpful agent named '{self.name}'.
|
||||||
|
You have been submitted this task by your manager.
|
||||||
|
---
|
||||||
|
Task:
|
||||||
|
{task}
|
||||||
|
---
|
||||||
|
You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer.
|
||||||
|
|
||||||
|
Your final_answer WILL HAVE to contain these parts:
|
||||||
|
### 1. Task outcome (short version):
|
||||||
|
### 2. Task outcome (extremely detailed version):
|
||||||
|
### 3. Additional context (if relevant):
|
||||||
|
|
||||||
|
Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
|
||||||
|
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
|
||||||
|
<<additional_prompting>>"""
|
||||||
|
if self.additional_prompting:
|
||||||
|
full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip()
|
||||||
|
else:
|
||||||
|
full_task = full_task.replace("\n<<additional_prompting>>", "").strip()
|
||||||
|
return full_task
|
||||||
|
|
||||||
|
def __call__(self, request, **kwargs):
|
||||||
|
full_task = self.write_full_task(request)
|
||||||
|
output = self.agent.run(full_task, **kwargs)
|
||||||
|
if self.provide_run_summary:
|
||||||
|
answer = f"Here is the final answer from your managed agent '{self.name}':\n"
|
||||||
|
answer += str(output)
|
||||||
|
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
|
||||||
|
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
|
||||||
|
content = message["content"]
|
||||||
|
if len(str(content)) < 1000 or "[FACTS LIST]" in str(content):
|
||||||
|
answer += "\n" + str(content) + "\n---"
|
||||||
|
else:
|
||||||
|
answer += "\n" + str(content)[:1000] + "\n(...Step was truncated because too long)...\n---"
|
||||||
|
answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
|
||||||
|
return answer
|
||||||
|
else:
|
||||||
|
return output
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
|
|||||||
|
|
||||||
|
|
||||||
def custom_print(*args):
|
def custom_print(*args):
|
||||||
return " ".join(map(str, args))
|
return None
|
||||||
|
|
||||||
|
|
||||||
BASE_PYTHON_TOOLS = {
|
BASE_PYTHON_TOOLS = {
|
||||||
|
|||||||
@@ -332,10 +332,10 @@ final_answer("Shanghai")
|
|||||||
---
|
---
|
||||||
Task: "What is the current age of the pope, raised to the power 0.36?"
|
Task: "What is the current age of the pope, raised to the power 0.36?"
|
||||||
|
|
||||||
Thought: I will use the tool `search` to get the age of the pope, then raise it to the power 0.36.
|
Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
pope_age = search(query="current pope age")
|
pope_age = wiki(query="current pope age")
|
||||||
print("Pope age:", pope_age)
|
print("Pope age:", pope_age)
|
||||||
```<end_action>
|
```<end_action>
|
||||||
Observation:
|
Observation:
|
||||||
@@ -348,16 +348,16 @@ pope_current_age = 85 ** 0.36
|
|||||||
final_answer(pope_current_age)
|
final_answer(pope_current_age)
|
||||||
```<end_action>
|
```<end_action>
|
||||||
|
|
||||||
Above example were using notional tools that might not exist for you. You only have acces to those tools:
|
Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have acces to those tools (and no other tool):
|
||||||
|
|
||||||
<<tool_descriptions>>
|
<<tool_descriptions>>
|
||||||
|
|
||||||
You also can perform computations in the Python code that you generate.
|
<<managed_agents_descriptions>>
|
||||||
|
|
||||||
Here are the rules you should always follow to solve your task:
|
Here are the rules you should always follow to solve your task:
|
||||||
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
|
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
|
||||||
2. Use only variables that you have defined!
|
2. Use only variables that you have defined!
|
||||||
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'.
|
||||||
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
||||||
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
||||||
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||||
@@ -410,6 +410,8 @@ Task:
|
|||||||
Your plan can leverage any of these tools:
|
Your plan can leverage any of these tools:
|
||||||
{tool_descriptions}
|
{tool_descriptions}
|
||||||
|
|
||||||
|
{managed_agents_descriptions}
|
||||||
|
|
||||||
List of facts that you know:
|
List of facts that you know:
|
||||||
```
|
```
|
||||||
{answer_facts}
|
{answer_facts}
|
||||||
@@ -453,9 +455,11 @@ USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
|
|||||||
{task}
|
{task}
|
||||||
```
|
```
|
||||||
|
|
||||||
You have access to these tools:
|
You have access to these tools and only these:
|
||||||
{tool_descriptions}
|
{tool_descriptions}
|
||||||
|
|
||||||
|
{managed_agents_descriptions}
|
||||||
|
|
||||||
Here is the up to date list of facts that you know:
|
Here is the up to date list of facts that you know:
|
||||||
```
|
```
|
||||||
{facts_update}
|
{facts_update}
|
||||||
|
|||||||
@@ -434,7 +434,7 @@ def evaluate_call(call, state, static_tools, custom_tools):
|
|||||||
global PRINT_OUTPUTS
|
global PRINT_OUTPUTS
|
||||||
PRINT_OUTPUTS += output + "\n"
|
PRINT_OUTPUTS += output + "\n"
|
||||||
# cap the number of lines
|
# cap the number of lines
|
||||||
return output
|
return None
|
||||||
else: # Assume it's a callable object
|
else: # Assume it's a callable object
|
||||||
output = func(*args, **kwargs)
|
output = func(*args, **kwargs)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -20,7 +20,14 @@ import uuid
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers.agents.agent_types import AgentText
|
from transformers.agents.agent_types import AgentText
|
||||||
from transformers.agents.agents import AgentMaxIterationsError, CodeAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
from transformers.agents.agents import (
|
||||||
|
AgentMaxIterationsError,
|
||||||
|
CodeAgent,
|
||||||
|
ManagedAgent,
|
||||||
|
ReactCodeAgent,
|
||||||
|
ReactJsonAgent,
|
||||||
|
Toolbox,
|
||||||
|
)
|
||||||
from transformers.agents.default_tools import PythonInterpreterTool
|
from transformers.agents.default_tools import PythonInterpreterTool
|
||||||
from transformers.testing_utils import require_torch
|
from transformers.testing_utils import require_torch
|
||||||
|
|
||||||
@@ -235,3 +242,19 @@ Action:
|
|||||||
)
|
)
|
||||||
res = agent.run("ok")
|
res = agent.run("ok")
|
||||||
assert res[0] == 0.5
|
assert res[0] == 0.5
|
||||||
|
|
||||||
|
def test_init_managed_agent(self):
|
||||||
|
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
||||||
|
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
||||||
|
assert managed_agent.name == "managed_agent"
|
||||||
|
assert managed_agent.description == "Empty"
|
||||||
|
|
||||||
|
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
|
||||||
|
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
||||||
|
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
||||||
|
manager_agent = ReactCodeAgent(
|
||||||
|
tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent]
|
||||||
|
)
|
||||||
|
assert "You can also give requests to team members." not in agent.system_prompt
|
||||||
|
assert "<<managed_agents_descriptions>>" not in agent.system_prompt
|
||||||
|
assert "You can also give requests to team members." in manager_agent.system_prompt
|
||||||
|
|||||||
Reference in New Issue
Block a user