Add token cost + runtime monitoring to Agent and HfEngine children (#34548)

* Add monitoring to Agent and HfEngine children
This commit is contained in:
Aymeric Roucher
2024-12-03 13:14:52 +01:00
committed by GitHub
parent ee37bf0d95
commit 901f504580
5 changed files with 344 additions and 126 deletions

View File

@@ -17,7 +17,8 @@
import json
import logging
import re
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .. import is_torch_available
from ..utils import logging as transformers_logging
@@ -25,6 +26,7 @@ from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole
from .monitoring import Monitor
from .prompts import (
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
@@ -353,17 +355,23 @@ class Agent:
def __init__(
self,
tools: Union[List[Tool], Toolbox],
llm_engine: Callable = HfApiEngine(),
system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template=None,
additional_args={},
llm_engine: Callable = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
additional_args: Dict = {},
max_iterations: int = 6,
tool_parser=parse_json_tool_call,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
verbose: int = 0,
grammar: Dict[str, str] = None,
managed_agents: List = None,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None,
monitor_metrics: bool = True,
):
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_parser is None:
tool_parser = parse_json_tool_call
self.agent_name = self.__class__.__name__
self.llm_engine = llm_engine
self.system_prompt_template = system_prompt
@@ -406,6 +414,15 @@ class Agent:
elif verbose == 2:
logger.setLevel(logging.DEBUG)
# Initialize step callbacks
self.step_callbacks = step_callbacks if step_callbacks is not None else []
# Initialize Monitor if monitor_metrics is True
self.monitor = None
if monitor_metrics:
self.monitor = Monitor(self.llm_engine)
self.step_callbacks.append(self.monitor.update_metrics)
@property
def toolbox(self) -> Toolbox:
"""Get the toolbox currently available to the agent"""
@@ -578,13 +595,19 @@ class CodeAgent(Agent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
@@ -700,14 +723,23 @@ class ReactAgent(Agent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0],
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
plan_type: Optional[str] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
if plan_type is None:
plan_type = SUPPORTED_PLAN_TYPES[0]
else:
assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
super().__init__(
tools=tools,
@@ -776,16 +808,24 @@ class ReactAgent(Agent):
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
step_start_time = time.time()
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
try:
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
self.step(step_log_entry)
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
step_log_entry["error"] = e
finally:
step_end_time = time.time()
step_log_entry["step_end_time"] = step_end_time
step_log_entry["step_duration"] = step_end_time - step_start_time
self.logs.append(step_log_entry)
for callback in self.step_callbacks:
callback(step_log_entry)
iteration += 1
yield self.logs[-1]
yield step_log_entry
if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
@@ -794,6 +834,9 @@ class ReactAgent(Agent):
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
final_step_log["step_duration"] = 0
for callback in self.step_callbacks:
callback(final_step_log)
yield final_step_log
yield final_answer
@@ -805,16 +848,24 @@ class ReactAgent(Agent):
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
step_start_time = time.time()
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
self.step(step_log_entry)
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
step_log_entry["error"] = e
finally:
step_end_time = time.time()
step_log_entry["step_end_time"] = step_end_time
step_log_entry["step_duration"] = step_end_time - step_start_time
self.logs.append(step_log_entry)
for callback in self.step_callbacks:
callback(step_log_entry)
iteration += 1
if final_answer is None and iteration == self.max_iterations:
@@ -824,6 +875,9 @@ class ReactAgent(Agent):
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
final_step_log["step_duration"] = 0
for callback in self.step_callbacks:
callback(final_step_log)
return final_answer
@@ -937,13 +991,19 @@ class ReactJsonAgent(ReactAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
@@ -954,7 +1014,7 @@ class ReactJsonAgent(ReactAgent):
**kwargs,
)
def step(self):
def step(self, log_entry: Dict[str, Any]):
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
@@ -965,9 +1025,7 @@ class ReactJsonAgent(ReactAgent):
self.logger.debug("===== New step =====")
# Add new step in logs
current_step_logs = {}
self.logs.append(current_step_logs)
current_step_logs["agent_memory"] = agent_memory.copy()
log_entry["agent_memory"] = agent_memory.copy()
self.logger.info("===== Calling LLM with this last message: =====")
self.logger.info(self.prompt[-1])
@@ -981,7 +1039,7 @@ class ReactJsonAgent(ReactAgent):
raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("===== Output message of the LLM: =====")
self.logger.debug(llm_output)
current_step_logs["llm_output"] = llm_output
log_entry["llm_output"] = llm_output
# Parse
self.logger.debug("===== Extracting action =====")
@@ -992,8 +1050,8 @@ class ReactJsonAgent(ReactAgent):
except Exception as e:
raise AgentParsingError(f"Could not parse the given action: {e}.")
current_step_logs["rationale"] = rationale
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
log_entry["rationale"] = rationale
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
# Execute
self.logger.warning("=== Agent thoughts:")
@@ -1011,8 +1069,8 @@ class ReactJsonAgent(ReactAgent):
answer = arguments
else:
answer = arguments
current_step_logs["final_answer"] = answer
return current_step_logs
log_entry["final_answer"] = answer
return answer
else:
if arguments is None:
arguments = {}
@@ -1030,8 +1088,8 @@ class ReactJsonAgent(ReactAgent):
else:
updated_information = str(observation).strip()
self.logger.info(updated_information)
current_step_logs["observation"] = updated_information
return current_step_logs
log_entry["observation"] = updated_information
return log_entry
class ReactCodeAgent(ReactAgent):
@@ -1044,14 +1102,20 @@ class ReactCodeAgent(ReactAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
@@ -1075,7 +1139,7 @@ class ReactCodeAgent(ReactAgent):
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.custom_tools = {}
def step(self):
def step(self, log_entry: Dict[str, Any]):
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
@@ -1083,13 +1147,10 @@ class ReactCodeAgent(ReactAgent):
agent_memory = self.write_inner_memory_from_logs()
self.prompt = agent_memory.copy()
self.logger.debug("===== New step =====")
# Add new step in logs
current_step_logs = {}
self.logs.append(current_step_logs)
current_step_logs["agent_memory"] = agent_memory.copy()
log_entry["agent_memory"] = agent_memory.copy()
self.logger.info("===== Calling LLM with these last messages: =====")
self.logger.info(self.prompt[-2:])
@@ -1104,7 +1165,7 @@ class ReactCodeAgent(ReactAgent):
self.logger.debug("=== Output message of the LLM:")
self.logger.debug(llm_output)
current_step_logs["llm_output"] = llm_output
log_entry["llm_output"] = llm_output
# Parse
self.logger.debug("=== Extracting action ===")
@@ -1120,8 +1181,8 @@ class ReactCodeAgent(ReactAgent):
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg)
current_step_logs["rationale"] = rationale
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
log_entry["rationale"] = rationale
log_entry["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
# Execute
self.log_rationale_code_action(rationale, code_action)
@@ -1146,7 +1207,7 @@ class ReactCodeAgent(ReactAgent):
self.logger.warning("Last output from code snippet:")
self.logger.log(32, str(result))
observation += "Last output from code snippet:\n" + str(result)[:100000]
current_step_logs["observation"] = observation
log_entry["observation"] = observation
except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e):
@@ -1156,8 +1217,11 @@ class ReactCodeAgent(ReactAgent):
if line[: len("final_answer")] == "final_answer":
self.logger.log(33, "Final answer:")
self.logger.log(32, result)
current_step_logs["final_answer"] = result
return current_step_logs
log_entry["final_answer"] = result
return result
LENGTH_TRUNCATE_REPORTS = 1000
class ManagedAgent:
@@ -1200,10 +1264,14 @@ And even if your task resolution is not successful, please return as much contex
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
content = message["content"]
if len(str(content)) < 1000 or "[FACTS LIST]" in str(content):
if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
answer += "\n" + str(content) + "\n---"
else:
answer += "\n" + str(content)[:1000] + "\n(...Step was truncated because too long)...\n---"
answer += (
"\n"
+ str(content)[:LENGTH_TRUNCATE_REPORTS]
+ "\n(...Step was truncated because too long)...\n---"
)
answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
return answer
else:

View File

@@ -20,7 +20,12 @@ from typing import Dict, List, Optional
from huggingface_hub import InferenceClient
from .. import AutoTokenizer
from ..pipelines.base import Pipeline
from ..utils import logging
logger = logging.get_logger(__name__)
class MessageRole(str, Enum):
@@ -67,46 +72,32 @@ llama_role_conversions = {
}
class HfApiEngine:
"""A class to interact with Hugging Face's Inference API for language model interaction.
class HfEngine:
def __init__(self, model_id: Optional[str] = None):
self.last_input_token_count = None
self.last_output_token_count = None
if model_id is None:
model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
logger.warning(f"Using default model for token counting: '{model_id}'")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
except Exception as e:
logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.")
self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
def get_token_counts(self):
return {
"input_token_count": self.last_input_token_count,
"output_token_count": self.last_output_token_count,
}
Parameters:
model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
token (`str`, *optional*):
The Hugging Face API token for authentication. If not provided, the class will use the token stored in the Hugging Face CLI configuration.
max_tokens (`int`, *optional*, defaults to 1500):
The maximum number of tokens allowed in the output.
timeout (`int`, *optional*, defaults to 120):
Timeout for the API request, in seconds.
Raises:
ValueError:
If the model name is not provided.
"""
def __init__(
self,
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
token: Optional[str] = None,
max_tokens: Optional[int] = 1500,
timeout: Optional[int] = 120,
def generate(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
):
"""Initialize the HfApiEngine."""
if not model:
raise ValueError("Model name must be provided.")
self.model = model
self.client = InferenceClient(self.model, token=token, timeout=timeout)
self.max_tokens = max_tokens
raise NotImplementedError
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = [],
grammar: Optional[str] = None,
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
) -> str:
"""Process the input messages and return the model's response.
@@ -136,6 +127,57 @@ class HfApiEngine:
"Quantum mechanics is the branch of physics that studies..."
```
"""
if not isinstance(messages, List):
raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.")
if stop_sequences is None:
stop_sequences = []
response = self.generate(messages, stop_sequences, grammar)
self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True))
self.last_output_token_count = len(self.tokenizer.encode(response))
# Remove stop sequences from LLM output
for stop_seq in stop_sequences:
if response[-len(stop_seq) :] == stop_seq:
response = response[: -len(stop_seq)]
return response
class HfApiEngine(HfEngine):
"""A class to interact with Hugging Face's Inference API for language model interaction.
This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
Parameters:
model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
token (`str`, *optional*):
Token used by the Hugging Face API for authentication.
If not provided, the class will use the token stored in the Hugging Face CLI configuration.
max_tokens (`int`, *optional*, defaults to 1500):
The maximum number of tokens allowed in the output.
timeout (`int`, *optional*, defaults to 120):
Timeout for the API request, in seconds.
Raises:
ValueError:
If the model name is not provided.
"""
def __init__(
self,
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
token: Optional[str] = None,
max_tokens: Optional[int] = 1500,
timeout: Optional[int] = 120,
):
super().__init__(model_id=model)
self.model = model
self.client = InferenceClient(self.model, token=token, timeout=timeout)
self.max_tokens = max_tokens
def generate(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
) -> str:
# Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
@@ -148,41 +190,40 @@ class HfApiEngine:
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
response = response.choices[0].message.content
# Remove stop sequences from LLM output
for stop_seq in stop_sequences:
if response[-len(stop_seq) :] == stop_seq:
response = response[: -len(stop_seq)]
return response
class TransformersEngine:
class TransformersEngine(HfEngine):
"""This engine uses a pre-initialized local text-generation pipeline."""
def __init__(self, pipeline: Pipeline):
def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None):
super().__init__(model_id)
self.pipeline = pipeline
def __call__(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
def generate(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_length: int = 1500,
) -> str:
# Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
# Get LLM output
if stop_sequences is not None and len(stop_sequences) > 0:
stop_strings = stop_sequences
else:
stop_strings = None
output = self.pipeline(
messages,
stop_strings=stop_sequences,
max_length=1500,
stop_strings=stop_strings,
max_length=max_length,
tokenizer=self.pipeline.tokenizer,
)
response = output[0]["generated_text"][-1]["content"]
# Remove stop sequences from LLM output
if stop_sequences is not None:
for stop_seq in stop_sequences:
if response[-len(stop_seq) :] == stop_seq:
response = response[: -len(stop_seq)]
return response

View File

@@ -14,8 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import logging
from .agent_types import AgentAudio, AgentImage, AgentText
from .agents import ReactAgent
logger = logging.get_logger(__name__)
def pull_message(step_log: dict, test_mode: bool = True):
@@ -54,7 +57,7 @@ def pull_message(step_log: dict, test_mode: bool = True):
)
def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kwargs):
def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
try:
@@ -91,3 +94,24 @@ def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kw
)
else:
yield ChatMessage(role="assistant", content=str(final_answer))
class Monitor:
    """Accumulate per-step durations and token usage of a tracked LLM engine.

    An instance is meant to be used as a step callback: `update_metrics` is
    called with a step log dictionary and reads the engine's most recent token
    counts from its `last_input_token_count` / `last_output_token_count`
    attributes when they are available.

    Args:
        tracked_llm_engine: The engine object (or plain callable) whose token
            counters are read after each step. Engines that do not expose the
            counters are supported; only durations are recorded for them.
    """

    def __init__(self, tracked_llm_engine):
        self.step_durations = []
        self.tracked_llm_engine = tracked_llm_engine
        # Always start the totals at zero. Creating them only when the engine
        # already exposes its counters breaks `update_metrics` (AttributeError
        # on `+=`) for engines that set the counters lazily on their first call.
        self.total_input_token_count = 0
        self.total_output_token_count = 0

    def update_metrics(self, step_log):
        """Record the duration of a step and add the engine's latest token counts.

        Args:
            step_log: Step log dictionary; must contain a `"step_duration"` entry.

        Raises:
            KeyError: If `step_log` has no `"step_duration"` key.
        """
        step_duration = step_log["step_duration"]
        self.step_durations.append(step_duration)
        logger.info(f"Step {len(self.step_durations)}:")
        logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")

        input_count = getattr(self.tracked_llm_engine, "last_input_token_count", None)
        if input_count is None:
            # The engine does not report token usage (or has not been called yet).
            return

        self.total_input_token_count += input_count
        output_count = getattr(self.tracked_llm_engine, "last_output_token_count", None)
        if output_count is not None:
            # Guard against an engine that only filled in the input counter.
            self.total_output_token_count += output_count
        logger.info(f"- Input tokens: {self.total_input_token_count}")
        logger.info(f"- Output tokens: {self.total_output_token_count}")

View File

@@ -785,21 +785,22 @@ def launch_gradio_demo(tool_class: Tool):
def fn(*args, **kwargs):
return tool(*args, **kwargs)
TYPE_TO_COMPONENT_CLASS_MAPPING = {
"image": gr.Image,
"audio": gr.Audio,
"string": gr.Textbox,
"integer": gr.Textbox,
"number": gr.Textbox,
}
gradio_inputs = []
for input_name, input_details in tool_class.inputs.items():
input_type = input_details["type"]
if input_type == "image":
gradio_inputs.append(gr.Image(label=input_name))
elif input_type == "audio":
gradio_inputs.append(gr.Audio(label=input_name))
elif input_type in ["string", "integer", "number"]:
gradio_inputs.append(gr.Textbox(label=input_name))
else:
error_message = f"Input type '{input_type}' not supported."
raise ValueError(error_message)
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
new_component = input_gradio_component_class(label=input_name)
gradio_inputs.append(new_component)
gradio_output = tool_class.output_type
assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported."
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
gradio_output = output_gradio_componentclass(label=input_name)
gr.Interface(
fn=fn,

View File

@@ -21,11 +21,95 @@ from transformers.agents.monitoring import stream_to_gradio
class MonitoringTester(unittest.TestCase):
def test_code_agent_metrics(self):
    """A run that ends on a final-answer code blob leaves the monitor totals at the engine's counts."""

    class StubEngine:
        # Engine double: fixed token counters plus a canned code answer.
        def __init__(self):
            self.last_input_token_count, self.last_output_token_count = 10, 20

        def __call__(self, prompt, **kwargs):
            return """
Code:
```py
final_answer('This is the final answer.')
```"""

    code_agent = ReactCodeAgent(tools=[], llm_engine=StubEngine(), max_iterations=1)
    code_agent.run("Fake task")

    monitor = code_agent.monitor
    self.assertEqual(monitor.total_input_token_count, 10)
    self.assertEqual(monitor.total_output_token_count, 20)
def test_json_agent_metrics(self):
    """A run that ends on a JSON `final_answer` action leaves the monitor totals at the engine's counts."""

    class StubEngine:
        # Engine double: fixed token counters plus a canned JSON action.
        def __init__(self):
            self.last_input_token_count, self.last_output_token_count = 10, 20

        def __call__(self, prompt, **kwargs):
            return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'

    json_agent = ReactJsonAgent(tools=[], llm_engine=StubEngine(), max_iterations=1)
    json_agent.run("Fake task")

    monitor = json_agent.monitor
    self.assertEqual(monitor.total_input_token_count, 10)
    self.assertEqual(monitor.total_output_token_count, 20)
def test_code_agent_metrics_max_iterations(self):
    """With an unparsable answer and `max_iterations=1`, the totals are twice the engine's counts.

    The doubled totals are consistent with the step callbacks firing once for
    the failed step and once more for the final max-iterations step log.
    """

    class StubEngine:
        # Engine double: fixed token counters and an answer with no code blob.
        def __init__(self):
            self.last_input_token_count, self.last_output_token_count = 10, 20

        def __call__(self, prompt, **kwargs):
            return "Malformed answer"

    code_agent = ReactCodeAgent(tools=[], llm_engine=StubEngine(), max_iterations=1)
    code_agent.run("Fake task")

    monitor = code_agent.monitor
    self.assertEqual(monitor.total_input_token_count, 20)
    self.assertEqual(monitor.total_output_token_count, 40)
def test_code_agent_metrics_generation_error(self):
    """When the engine raises on every call, the totals are still twice the engine's counts.

    The fake engine keeps its preset counters even though it never returns, so
    each callback invocation adds the same 10 / 20 pair.
    """

    class StubEngine:
        # Engine double: fixed token counters; every call fails.
        def __init__(self):
            self.last_input_token_count, self.last_output_token_count = 10, 20

        def __call__(self, prompt, **kwargs):
            raise AgentError

    code_agent = ReactCodeAgent(tools=[], llm_engine=StubEngine(), max_iterations=1)
    code_agent.run("Fake task")

    monitor = code_agent.monitor
    self.assertEqual(monitor.total_input_token_count, 20)
    self.assertEqual(monitor.total_output_token_count, 40)
def test_streaming_agent_text_output(self):
def dummy_llm_engine(prompt, **kwargs):
return """
Code:
````
```py
final_answer('This is the final answer.')
```"""