From a24a9a66f446dcb9277e31d16255536c5ce27aa6 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Mon, 29 Jul 2024 20:12:44 +0200 Subject: [PATCH] Add stream messages from agent run for gradio chatbot (#32142) * Add stream_to_gradio method for running agent in gradio demo --- docs/source/en/agents.md | 51 ++++++++++++++++++ docs/source/en/main_classes/agent.md | 4 ++ src/transformers/__init__.py | 2 + src/transformers/agents/__init__.py | 2 + src/transformers/agents/monitoring.py | 75 +++++++++++++++++++++++++++ 5 files changed, 134 insertions(+) create mode 100644 src/transformers/agents/monitoring.py diff --git a/docs/source/en/agents.md b/docs/source/en/agents.md index d1c550f5d3..f335cb678f 100644 --- a/docs/source/en/agents.md +++ b/docs/source/en/agents.md @@ -509,3 +509,54 @@ agent = ReactCodeAgent(tools=[search_tool]) agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?") ``` + +## Gradio interface + +You can leverage `gradio.Chatbot`to display your agent's thoughts using `stream_to_gradio`, here is an example: + +```py +import gradio as gr +from transformers import ( + load_tool, + ReactCodeAgent, + HfEngine, + stream_to_gradio, +) + +# Import tool from Hub +image_generation_tool = load_tool("m-ric/text-to-image") + +llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct") + +# Initialize the agent with the image generation tool +agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine) + + +def interact_with_agent(task): + messages = [] + messages.append(gr.ChatMessage(role="user", content=task)) + yield messages + for msg in stream_to_gradio(agent, task): + messages.append(msg) + yield messages + [ + gr.ChatMessage(role="assistant", content="⏳ Task not finished yet!") + ] + yield messages + + +with gr.Blocks() as demo: + text_input = gr.Textbox(lines=1, label="Chat Message", value="Make me a picture of the Statue of Liberty.") + submit = gr.Button("Run illustrator agent!") + chatbot = gr.Chatbot( + label="Agent", + type="messages", + avatar_images=( + None, + "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png", + ), + ) + submit.click(interact_with_agent, [text_input], [chatbot]) + +if __name__ == "__main__": + demo.launch() +``` \ No newline at end of file diff --git a/docs/source/en/main_classes/agent.md b/docs/source/en/main_classes/agent.md index 8376fb3648..444003615b 100644 --- a/docs/source/en/main_classes/agent.md +++ b/docs/source/en/main_classes/agent.md @@ -72,6 +72,10 @@ We provide two types of agents, based on the main [`Agent`] class: [[autodoc]] launch_gradio_demo +### stream_to_gradio + +[[autodoc]] stream_to_gradio + ### ToolCollection [[autodoc]] ToolCollection diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9108367f35..4c953bab6b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -67,6 +67,7 @@ _import_structure = { "ToolCollection", "launch_gradio_demo", "load_tool", + "stream_to_gradio", ], "audio_utils": [], "benchmark": [], @@ -4733,6 +4734,7 @@ if TYPE_CHECKING: ToolCollection, launch_gradio_demo, load_tool, + stream_to_gradio, ) from .configuration_utils import PretrainedConfig diff --git a/src/transformers/agents/__init__.py b/src/transformers/agents/__init__.py index 672977f988..c4de21a03d 100644 --- a/src/transformers/agents/__init__.py +++ b/src/transformers/agents/__init__.py @@ -26,6 +26,7 @@ from ..utils import ( _import_structure = { "agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"], "llm_engine": ["HfEngine"], + "monitoring": ["stream_to_gradio"], "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"], } @@ -45,6 +46,7 @@ else: if TYPE_CHECKING: from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox from .llm_engine import HfEngine + from .monitoring import stream_to_gradio from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool try: diff --git a/src/transformers/agents/monitoring.py b/src/transformers/agents/monitoring.py new file mode 100644 index 0000000000..291dc1dcf1 --- /dev/null +++ b/src/transformers/agents/monitoring.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .agent_types import AgentAudio, AgentImage, AgentText +from .agents import ReactAgent + + +def pull_message(step_log: dict): + try: + from gradio import ChatMessage + except ImportError: + raise ImportError("Gradio should be installed in order to launch a gradio demo.") + + if step_log.get("rationale"): + yield ChatMessage(role="assistant", content=step_log["rationale"]) + if step_log.get("tool_call"): + used_code = step_log["tool_call"]["tool_name"] == "code interpreter" + content = step_log["tool_call"]["tool_arguments"] + if used_code: + content = f"```py\n{content}\n```" + yield ChatMessage( + role="assistant", + metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"}, + content=content, + ) + if step_log.get("observation"): + yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```") + if step_log.get("error"): + yield ChatMessage( + role="assistant", + content=str(step_log["error"]), + metadata={"title": "💥 Error"}, + ) + + +def stream_to_gradio(agent: ReactAgent, task: str, **kwargs): + """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" + + try: + from gradio import ChatMessage + except ImportError: + raise ImportError("Gradio should be installed in order to launch a gradio demo.") + + for step_log in agent.run(task, stream=True, **kwargs): + if isinstance(step_log, dict): + for message in pull_message(step_log): + yield message + + if isinstance(step_log, AgentText): + yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```") + elif isinstance(step_log, AgentImage): + yield ChatMessage( + role="assistant", + content={"path": step_log.to_string(), "mime_type": "image/png"}, + ) + elif isinstance(step_log, AgentAudio): + yield ChatMessage( + role="assistant", + content={"path": step_log.to_string(), "mime_type": "audio/wav"}, + ) + else: + yield ChatMessage(role="assistant", content=step_log)