Agents: Small fixes in streaming to gradio + add tests (#34549)
* Better support transformers.agents in gradio: small fixes and additional tests
This commit is contained in:
@@ -1141,11 +1141,10 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
)
|
)
|
||||||
self.logger.warning("Print outputs:")
|
self.logger.warning("Print outputs:")
|
||||||
self.logger.log(32, self.state["print_outputs"])
|
self.logger.log(32, self.state["print_outputs"])
|
||||||
|
observation = "Print outputs:\n" + self.state["print_outputs"]
|
||||||
if result is not None:
|
if result is not None:
|
||||||
self.logger.warning("Last output from code snippet:")
|
self.logger.warning("Last output from code snippet:")
|
||||||
self.logger.log(32, str(result))
|
self.logger.log(32, str(result))
|
||||||
observation = "Print outputs:\n" + self.state["print_outputs"]
|
|
||||||
if result is not None:
|
|
||||||
observation += "Last output from code snippet:\n" + str(result)[:100000]
|
observation += "Last output from code snippet:\n" + str(result)[:100000]
|
||||||
current_step_logs["observation"] = observation
|
current_step_logs["observation"] = observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -18,10 +18,18 @@ from .agent_types import AgentAudio, AgentImage, AgentText
|
|||||||
from .agents import ReactAgent
|
from .agents import ReactAgent
|
||||||
|
|
||||||
|
|
||||||
def pull_message(step_log: dict):
|
def pull_message(step_log: dict, test_mode: bool = True):
|
||||||
try:
|
try:
|
||||||
from gradio import ChatMessage
|
from gradio import ChatMessage
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
if test_mode:
|
||||||
|
|
||||||
|
class ChatMessage:
|
||||||
|
def __init__(self, role, content, metadata=None):
|
||||||
|
self.role = role
|
||||||
|
self.content = content
|
||||||
|
self.metadata = metadata
|
||||||
|
else:
|
||||||
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
||||||
|
|
||||||
if step_log.get("rationale"):
|
if step_log.get("rationale"):
|
||||||
@@ -46,30 +54,40 @@ def pull_message(step_log: dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kwargs):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.

    Args:
        agent: Agent whose `run(task, stream=True, **kwargs)` output is consumed.
        task: Task handed to `agent.run`.
        test_mode: When True and gradio cannot be imported, a minimal stand-in
            `ChatMessage` class is used instead of raising.
        **kwargs: Forwarded unchanged to `agent.run`.

    Yields:
        The messages produced by `pull_message` for every dict step log, followed by
        one last message rendering the final answer (the last item produced by the run).
        Nothing is yielded when the run produces no item at all.

    Raises:
        ImportError: If gradio is not installed and `test_mode` is False.
    """
    try:
        from gradio import ChatMessage
    except ImportError as err:
        if not test_mode:
            # Keep the original import failure attached as the cause.
            raise ImportError("Gradio should be installed in order to launch a gradio demo.") from err

        class ChatMessage:
            def __init__(self, role, content, metadata=None):
                self.role = role
                self.content = content
                self.metadata = metadata

    # Sentinel rather than None so that an empty run can be told apart from any real log.
    nothing_yet = object()
    final_answer = nothing_yet

    for step_log in agent.run(task, stream=True, **kwargs):
        if isinstance(step_log, dict):
            for message in pull_message(step_log, test_mode=test_mode):
                yield message
        final_answer = step_log  # Last log is the run's final_answer

    if final_answer is nothing_yet:
        # The run produced nothing: relying on the loop variable here would raise
        # UnboundLocalError, so end the stream without a final message instead.
        return

    if isinstance(final_answer, AgentText):
        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
    elif isinstance(final_answer, AgentImage):
        yield ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(final_answer, AgentAudio):
        yield ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield ChatMessage(role="assistant", content=str(final_answer))
|
||||||
|
|||||||
@@ -848,6 +848,13 @@ def evaluate_ast(
|
|||||||
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
    """Cap `print_outputs` at `max_len_outputs` characters.

    Args:
        print_outputs: The text to cap.
        max_len_outputs: Maximum number of characters kept from `print_outputs`.

    Returns:
        `print_outputs` unchanged when it fits within the limit; otherwise its first
        `max_len_outputs` characters followed by a notice that it was truncated.
    """
    # `<=`: a text exactly at the limit loses nothing, so it must not be reported as truncated.
    if len(print_outputs) <= max_len_outputs:
        return print_outputs

    # No "Print outputs:" header is added here: the untruncated branch returns bare text,
    # so adding it only when truncating made the two branches disagree (and duplicated the
    # header for any caller that prepends its own).
    return (
        f"{print_outputs[:max_len_outputs]}"
        f"\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
    )
|
||||||
|
|
||||||
|
|
||||||
def evaluate_python_code(
|
def evaluate_python_code(
|
||||||
code: str,
|
code: str,
|
||||||
static_tools: Optional[Dict[str, Callable]] = None,
|
static_tools: Optional[Dict[str, Callable]] = None,
|
||||||
@@ -890,25 +897,12 @@ def evaluate_python_code(
|
|||||||
PRINT_OUTPUTS = ""
|
PRINT_OUTPUTS = ""
|
||||||
global OPERATIONS_COUNT
|
global OPERATIONS_COUNT
|
||||||
OPERATIONS_COUNT = 0
|
OPERATIONS_COUNT = 0
|
||||||
for node in expression.body:
|
|
||||||
try:
|
try:
|
||||||
|
for node in expression.body:
|
||||||
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||||
|
state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
|
||||||
|
return result
|
||||||
except InterpreterError as e:
|
except InterpreterError as e:
|
||||||
msg = ""
|
msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
|
||||||
if len(PRINT_OUTPUTS) > 0:
|
|
||||||
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
|
|
||||||
msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
|
|
||||||
else:
|
|
||||||
msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
|
|
||||||
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||||
raise InterpreterError(msg)
|
raise InterpreterError(msg)
|
||||||
finally:
|
|
||||||
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
|
|
||||||
state["print_outputs"] = PRINT_OUTPUTS
|
|
||||||
else:
|
|
||||||
state["print_outputs"] = (
|
|
||||||
PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
|
|
||||||
+ f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import ast
|
||||||
import base64
|
import base64
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
@@ -141,15 +142,19 @@ class Tool:
|
|||||||
required_attributes = {
|
required_attributes = {
|
||||||
"description": str,
|
"description": str,
|
||||||
"name": str,
|
"name": str,
|
||||||
"inputs": Dict,
|
"inputs": dict,
|
||||||
"output_type": str,
|
"output_type": str,
|
||||||
}
|
}
|
||||||
authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
|
authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
|
||||||
|
|
||||||
for attr, expected_type in required_attributes.items():
|
for attr, expected_type in required_attributes.items():
|
||||||
attr_value = getattr(self, attr, None)
|
attr_value = getattr(self, attr, None)
|
||||||
|
if attr_value is None:
|
||||||
|
raise TypeError(f"You must set an attribute {attr}.")
|
||||||
if not isinstance(attr_value, expected_type):
|
if not isinstance(attr_value, expected_type):
|
||||||
raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
|
raise TypeError(
|
||||||
|
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
||||||
|
)
|
||||||
for input_name, input_content in self.inputs.items():
|
for input_name, input_content in self.inputs.items():
|
||||||
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
|
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
|
||||||
assert (
|
assert (
|
||||||
@@ -248,7 +253,6 @@ class Tool:
|
|||||||
def from_hub(
|
def from_hub(
|
||||||
cls,
|
cls,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
model_repo_id: Optional[str] = None,
|
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -266,9 +270,6 @@ class Tool:
|
|||||||
Args:
|
Args:
|
||||||
repo_id (`str`):
|
repo_id (`str`):
|
||||||
The name of the repo on the Hub where your tool is defined.
|
The name of the repo on the Hub where your tool is defined.
|
||||||
model_repo_id (`str`, *optional*):
|
|
||||||
If your tool uses a model and you want to use a different model than the default, you can pass a second
|
|
||||||
repo ID or an endpoint url to this argument.
|
|
||||||
token (`str`, *optional*):
|
token (`str`, *optional*):
|
||||||
The token to identify you on hf.co. If unset, will use the token generated when running
|
The token to identify you on hf.co. If unset, will use the token generated when running
|
||||||
`huggingface-cli login` (stored in `~/.huggingface`).
|
`huggingface-cli login` (stored in `~/.huggingface`).
|
||||||
@@ -354,6 +355,9 @@ class Tool:
|
|||||||
if tool_class.output_type != custom_tool["output_type"]:
|
if tool_class.output_type != custom_tool["output_type"]:
|
||||||
tool_class.output_type = custom_tool["output_type"]
|
tool_class.output_type = custom_tool["output_type"]
|
||||||
|
|
||||||
|
if not isinstance(tool_class.inputs, dict):
|
||||||
|
tool_class.inputs = ast.literal_eval(tool_class.inputs)
|
||||||
|
|
||||||
return tool_class(**kwargs)
|
return tool_class(**kwargs)
|
||||||
|
|
||||||
def push_to_hub(
|
def push_to_hub(
|
||||||
|
|||||||
82
tests/agents/test_monitoring.py
Normal file
82
tests/agents/test_monitoring.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers.agents.agent_types import AgentImage
|
||||||
|
from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
|
||||||
|
from transformers.agents.monitoring import stream_to_gradio
|
||||||
|
|
||||||
|
|
||||||
|
class MonitoringTester(unittest.TestCase):
    """Checks the messages emitted by `stream_to_gradio` for several agent outcomes."""

    def test_streaming_agent_text_output(self):
        """A code agent calling `final_answer` with text ends the stream with that text."""

        def text_engine(prompt, **kwargs):
            return """
Code:
````
final_answer('This is the final answer.')
```"""

        code_agent = ReactCodeAgent(tools=[], llm_engine=text_engine, max_iterations=1)

        # Drain the generator so that every streamed message can be inspected.
        messages = list(stream_to_gradio(code_agent, task="Test task", test_mode=True))

        self.assertEqual(len(messages), 3)
        last_message = messages[-1]
        self.assertEqual(last_message.role, "assistant")
        self.assertIn("This is the final answer.", last_message.content)

    def test_streaming_agent_image_output(self):
        """A JSON agent answering with the `image` input ends the stream with an image payload."""

        def image_engine(prompt, **kwargs):
            return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'

        json_agent = ReactJsonAgent(tools=[], llm_engine=image_engine, max_iterations=1)

        # Drain the generator so that every streamed message can be inspected.
        stream = stream_to_gradio(json_agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True)
        messages = list(stream)

        self.assertEqual(len(messages), 2)
        last_message = messages[-1]
        self.assertEqual(last_message.role, "assistant")
        self.assertIsInstance(last_message.content, dict)
        self.assertEqual(last_message.content["path"], "path.png")
        self.assertEqual(last_message.content["mime_type"], "image/png")

    def test_streaming_with_agent_error(self):
        """An engine that raises `AgentError` still produces a final message mentioning the error."""

        def failing_engine(prompt, **kwargs):
            raise AgentError("Simulated agent error")

        code_agent = ReactCodeAgent(tools=[], llm_engine=failing_engine, max_iterations=1)

        # Drain the generator so that every streamed message can be inspected.
        messages = list(stream_to_gradio(code_agent, task="Test task", test_mode=True))

        self.assertEqual(len(messages), 3)
        last_message = messages[-1]
        self.assertEqual(last_message.role, "assistant")
        self.assertIn("Simulated agent error", last_message.content)
|
||||||
Reference in New Issue
Block a user