Add token cost + runtime monitoring to Agent and HfEngine children (#34548)
* Add monitoring to Agent and HfEngine children
This commit is contained in:
@@ -17,7 +17,8 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from .. import is_torch_available
|
||||
from ..utils import logging as transformers_logging
|
||||
@@ -25,6 +26,7 @@ from ..utils.import_utils import is_pygments_available
|
||||
from .agent_types import AgentAudio, AgentImage
|
||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||
from .llm_engine import HfApiEngine, MessageRole
|
||||
from .monitoring import Monitor
|
||||
from .prompts import (
|
||||
DEFAULT_CODE_SYSTEM_PROMPT,
|
||||
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
@@ -353,17 +355,23 @@ class Agent:
|
||||
def __init__(
|
||||
self,
|
||||
tools: Union[List[Tool], Toolbox],
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template=None,
|
||||
additional_args={},
|
||||
llm_engine: Callable = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_description_template: Optional[str] = None,
|
||||
additional_args: Dict = {},
|
||||
max_iterations: int = 6,
|
||||
tool_parser=parse_json_tool_call,
|
||||
tool_parser: Optional[Callable] = None,
|
||||
add_base_tools: bool = False,
|
||||
verbose: int = 0,
|
||||
grammar: Dict[str, str] = None,
|
||||
managed_agents: List = None,
|
||||
grammar: Optional[Dict[str, str]] = None,
|
||||
managed_agents: Optional[List] = None,
|
||||
step_callbacks: Optional[List[Callable]] = None,
|
||||
monitor_metrics: bool = True,
|
||||
):
|
||||
if system_prompt is None:
|
||||
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
|
||||
if tool_parser is None:
|
||||
tool_parser = parse_json_tool_call
|
||||
self.agent_name = self.__class__.__name__
|
||||
self.llm_engine = llm_engine
|
||||
self.system_prompt_template = system_prompt
|
||||
@@ -406,6 +414,15 @@ class Agent:
|
||||
elif verbose == 2:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Initialize step callbacks
|
||||
self.step_callbacks = step_callbacks if step_callbacks is not None else []
|
||||
|
||||
# Initialize Monitor if monitor_metrics is True
|
||||
self.monitor = None
|
||||
if monitor_metrics:
|
||||
self.monitor = Monitor(self.llm_engine)
|
||||
self.step_callbacks.append(self.monitor.update_metrics)
|
||||
|
||||
@property
|
||||
def toolbox(self) -> Toolbox:
|
||||
"""Get the toolbox currently available to the agent"""
|
||||
@@ -578,13 +595,19 @@ class CodeAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
grammar: Dict[str, str] = None,
|
||||
llm_engine: Optional[Callable] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_description_template: Optional[str] = None,
|
||||
grammar: Optional[Dict[str, str]] = None,
|
||||
additional_authorized_imports: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_engine is None:
|
||||
llm_engine = HfApiEngine()
|
||||
if system_prompt is None:
|
||||
system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
|
||||
if tool_description_template is None:
|
||||
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
llm_engine=llm_engine,
|
||||
@@ -700,14 +723,23 @@ class ReactAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
grammar: Dict[str, str] = None,
|
||||
plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0],
|
||||
llm_engine: Optional[Callable] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_description_template: Optional[str] = None,
|
||||
grammar: Optional[Dict[str, str]] = None,
|
||||
plan_type: Optional[str] = None,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_engine is None:
|
||||
llm_engine = HfApiEngine()
|
||||
if system_prompt is None:
|
||||
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
|
||||
if tool_description_template is None:
|
||||
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||
if plan_type is None:
|
||||
plan_type = SUPPORTED_PLAN_TYPES[0]
|
||||
else:
|
||||
assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
@@ -776,16 +808,24 @@ class ReactAgent(Agent):
|
||||
final_answer = None
|
||||
iteration = 0
|
||||
while final_answer is None and iteration < self.max_iterations:
|
||||
step_start_time = time.time()
|
||||
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
|
||||
try:
|
||||
step_logs = self.step()
|
||||
if "final_answer" in step_logs:
|
||||
final_answer = step_logs["final_answer"]
|
||||
self.step(step_log_entry)
|
||||
if "final_answer" in step_log_entry:
|
||||
final_answer = step_log_entry["final_answer"]
|
||||
except AgentError as e:
|
||||
self.logger.error(e, exc_info=1)
|
||||
self.logs[-1]["error"] = e
|
||||
step_log_entry["error"] = e
|
||||
finally:
|
||||
step_end_time = time.time()
|
||||
step_log_entry["step_end_time"] = step_end_time
|
||||
step_log_entry["step_duration"] = step_end_time - step_start_time
|
||||
self.logs.append(step_log_entry)
|
||||
for callback in self.step_callbacks:
|
||||
callback(step_log_entry)
|
||||
iteration += 1
|
||||
yield self.logs[-1]
|
||||
yield step_log_entry
|
||||
|
||||
if final_answer is None and iteration == self.max_iterations:
|
||||
error_message = "Reached max iterations."
|
||||
@@ -794,6 +834,9 @@ class ReactAgent(Agent):
|
||||
self.logger.error(error_message, exc_info=1)
|
||||
final_answer = self.provide_final_answer(task)
|
||||
final_step_log["final_answer"] = final_answer
|
||||
final_step_log["step_duration"] = 0
|
||||
for callback in self.step_callbacks:
|
||||
callback(final_step_log)
|
||||
yield final_step_log
|
||||
|
||||
yield final_answer
|
||||
@@ -805,16 +848,24 @@ class ReactAgent(Agent):
|
||||
final_answer = None
|
||||
iteration = 0
|
||||
while final_answer is None and iteration < self.max_iterations:
|
||||
step_start_time = time.time()
|
||||
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
|
||||
try:
|
||||
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
||||
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
||||
step_logs = self.step()
|
||||
if "final_answer" in step_logs:
|
||||
final_answer = step_logs["final_answer"]
|
||||
self.step(step_log_entry)
|
||||
if "final_answer" in step_log_entry:
|
||||
final_answer = step_log_entry["final_answer"]
|
||||
except AgentError as e:
|
||||
self.logger.error(e, exc_info=1)
|
||||
self.logs[-1]["error"] = e
|
||||
step_log_entry["error"] = e
|
||||
finally:
|
||||
step_end_time = time.time()
|
||||
step_log_entry["step_end_time"] = step_end_time
|
||||
step_log_entry["step_duration"] = step_end_time - step_start_time
|
||||
self.logs.append(step_log_entry)
|
||||
for callback in self.step_callbacks:
|
||||
callback(step_log_entry)
|
||||
iteration += 1
|
||||
|
||||
if final_answer is None and iteration == self.max_iterations:
|
||||
@@ -824,6 +875,9 @@ class ReactAgent(Agent):
|
||||
self.logger.error(error_message, exc_info=1)
|
||||
final_answer = self.provide_final_answer(task)
|
||||
final_step_log["final_answer"] = final_answer
|
||||
final_step_log["step_duration"] = 0
|
||||
for callback in self.step_callbacks:
|
||||
callback(final_step_log)
|
||||
|
||||
return final_answer
|
||||
|
||||
@@ -937,13 +991,19 @@ class ReactJsonAgent(ReactAgent):
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
grammar: Dict[str, str] = None,
|
||||
llm_engine: Optional[Callable] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_description_template: Optional[str] = None,
|
||||
grammar: Optional[Dict[str, str]] = None,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_engine is None:
|
||||
llm_engine = HfApiEngine()
|
||||
if system_prompt is None:
|
||||
system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT
|
||||
if tool_description_template is None:
|
||||
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
llm_engine=llm_engine,
|
||||
@@ -954,7 +1014,7 @@ class ReactJsonAgent(ReactAgent):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def step(self):
|
||||
def step(self, log_entry: Dict[str, Any]):
|
||||
"""
|
||||
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
||||
The errors are raised here, they are caught and logged in the run() method.
|
||||
@@ -965,9 +1025,7 @@ class ReactJsonAgent(ReactAgent):
|
||||
self.logger.debug("===== New step =====")
|
||||
|
||||
# Add new step in logs
|
||||
current_step_logs = {}
|
||||
self.logs.append(current_step_logs)
|
||||
current_step_logs["agent_memory"] = agent_memory.copy()
|
||||
log_entry["agent_memory"] = agent_memory.copy()
|
||||
|
||||
self.logger.info("===== Calling LLM with this last message: =====")
|
||||
self.logger.info(self.prompt[-1])
|
||||
@@ -981,7 +1039,7 @@ class ReactJsonAgent(ReactAgent):
|
||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||
self.logger.debug("===== Output message of the LLM: =====")
|
||||
self.logger.debug(llm_output)
|
||||
current_step_logs["llm_output"] = llm_output
|
||||
log_entry["llm_output"] = llm_output
|
||||
|
||||
# Parse
|
||||
self.logger.debug("===== Extracting action =====")
|
||||
@@ -992,8 +1050,8 @@ class ReactJsonAgent(ReactAgent):
|
||||
except Exception as e:
|
||||
raise AgentParsingError(f"Could not parse the given action: {e}.")
|
||||
|
||||
current_step_logs["rationale"] = rationale
|
||||
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
||||
log_entry["rationale"] = rationale
|
||||
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
||||
|
||||
# Execute
|
||||
self.logger.warning("=== Agent thoughts:")
|
||||
@@ -1011,8 +1069,8 @@ class ReactJsonAgent(ReactAgent):
|
||||
answer = arguments
|
||||
else:
|
||||
answer = arguments
|
||||
current_step_logs["final_answer"] = answer
|
||||
return current_step_logs
|
||||
log_entry["final_answer"] = answer
|
||||
return answer
|
||||
else:
|
||||
if arguments is None:
|
||||
arguments = {}
|
||||
@@ -1030,8 +1088,8 @@ class ReactJsonAgent(ReactAgent):
|
||||
else:
|
||||
updated_information = str(observation).strip()
|
||||
self.logger.info(updated_information)
|
||||
current_step_logs["observation"] = updated_information
|
||||
return current_step_logs
|
||||
log_entry["observation"] = updated_information
|
||||
return log_entry
|
||||
|
||||
|
||||
class ReactCodeAgent(ReactAgent):
|
||||
@@ -1044,14 +1102,20 @@ class ReactCodeAgent(ReactAgent):
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
grammar: Dict[str, str] = None,
|
||||
llm_engine: Optional[Callable] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_description_template: Optional[str] = None,
|
||||
grammar: Optional[Dict[str, str]] = None,
|
||||
additional_authorized_imports: Optional[List[str]] = None,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_engine is None:
|
||||
llm_engine = HfApiEngine()
|
||||
if system_prompt is None:
|
||||
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
|
||||
if tool_description_template is None:
|
||||
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
llm_engine=llm_engine,
|
||||
@@ -1075,7 +1139,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
||||
self.custom_tools = {}
|
||||
|
||||
def step(self):
|
||||
def step(self, log_entry: Dict[str, Any]):
|
||||
"""
|
||||
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
||||
The errors are raised here, they are caught and logged in the run() method.
|
||||
@@ -1083,13 +1147,10 @@ class ReactCodeAgent(ReactAgent):
|
||||
agent_memory = self.write_inner_memory_from_logs()
|
||||
|
||||
self.prompt = agent_memory.copy()
|
||||
|
||||
self.logger.debug("===== New step =====")
|
||||
|
||||
# Add new step in logs
|
||||
current_step_logs = {}
|
||||
self.logs.append(current_step_logs)
|
||||
current_step_logs["agent_memory"] = agent_memory.copy()
|
||||
log_entry["agent_memory"] = agent_memory.copy()
|
||||
|
||||
self.logger.info("===== Calling LLM with these last messages: =====")
|
||||
self.logger.info(self.prompt[-2:])
|
||||
@@ -1104,7 +1165,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
|
||||
self.logger.debug("=== Output message of the LLM:")
|
||||
self.logger.debug(llm_output)
|
||||
current_step_logs["llm_output"] = llm_output
|
||||
log_entry["llm_output"] = llm_output
|
||||
|
||||
# Parse
|
||||
self.logger.debug("=== Extracting action ===")
|
||||
@@ -1120,8 +1181,8 @@ class ReactCodeAgent(ReactAgent):
|
||||
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
|
||||
raise AgentParsingError(error_msg)
|
||||
|
||||
current_step_logs["rationale"] = rationale
|
||||
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
||||
log_entry["rationale"] = rationale
|
||||
log_entry["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
||||
|
||||
# Execute
|
||||
self.log_rationale_code_action(rationale, code_action)
|
||||
@@ -1146,7 +1207,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
self.logger.warning("Last output from code snippet:")
|
||||
self.logger.log(32, str(result))
|
||||
observation += "Last output from code snippet:\n" + str(result)[:100000]
|
||||
current_step_logs["observation"] = observation
|
||||
log_entry["observation"] = observation
|
||||
except Exception as e:
|
||||
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
||||
if "'dict' object has no attribute 'read'" in str(e):
|
||||
@@ -1156,8 +1217,11 @@ class ReactCodeAgent(ReactAgent):
|
||||
if line[: len("final_answer")] == "final_answer":
|
||||
self.logger.log(33, "Final answer:")
|
||||
self.logger.log(32, result)
|
||||
current_step_logs["final_answer"] = result
|
||||
return current_step_logs
|
||||
log_entry["final_answer"] = result
|
||||
return result
|
||||
|
||||
|
||||
LENGTH_TRUNCATE_REPORTS = 1000
|
||||
|
||||
|
||||
class ManagedAgent:
|
||||
@@ -1200,10 +1264,14 @@ And even if your task resolution is not successful, please return as much contex
|
||||
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
|
||||
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
|
||||
content = message["content"]
|
||||
if len(str(content)) < 1000 or "[FACTS LIST]" in str(content):
|
||||
if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
|
||||
answer += "\n" + str(content) + "\n---"
|
||||
else:
|
||||
answer += "\n" + str(content)[:1000] + "\n(...Step was truncated because too long)...\n---"
|
||||
answer += (
|
||||
"\n"
|
||||
+ str(content)[:LENGTH_TRUNCATE_REPORTS]
|
||||
+ "\n(...Step was truncated because too long)...\n---"
|
||||
)
|
||||
answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
|
||||
return answer
|
||||
else:
|
||||
|
||||
@@ -20,7 +20,12 @@ from typing import Dict, List, Optional
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
from .. import AutoTokenizer
|
||||
from ..pipelines.base import Pipeline
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
@@ -67,46 +72,32 @@ llama_role_conversions = {
|
||||
}
|
||||
|
||||
|
||||
class HfApiEngine:
|
||||
"""A class to interact with Hugging Face's Inference API for language model interaction.
|
||||
class HfEngine:
|
||||
def __init__(self, model_id: Optional[str] = None):
    """Reset the token counters and load the tokenizer used for token counting.

    Args:
        model_id (`str`, *optional*):
            Identifier handed to `AutoTokenizer.from_pretrained`. When `None`,
            `"HuggingFaceTB/SmolLM2-1.7B-Instruct"` is used and a warning is logged.
    """
    # Both counters start out unset; nothing in this initializer computes a count.
    self.last_input_token_count = self.last_output_token_count = None

    if model_id is None:
        model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        logger.warning(f"Using default model for token counting: '{model_id}'")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        # Best effort: any failure to load the requested tokenizer falls back to the default one.
        logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.")
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
    self.tokenizer = tokenizer
|
||||
|
||||
This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
|
||||
def get_token_counts(self):
    """Return the engine's current token counters as a dictionary.

    Returns:
        `dict`: `"input_token_count"` and `"output_token_count"`, holding the current values of
        `self.last_input_token_count` and `self.last_output_token_count` respectively.
    """
    counts = {}
    counts["input_token_count"] = self.last_input_token_count
    counts["output_token_count"] = self.last_output_token_count
    return counts
|
||||
|
||||
Parameters:
|
||||
model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
|
||||
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
|
||||
token (`str`, *optional*):
|
||||
The Hugging Face API token for authentication. If not provided, the class will use the token stored in the Hugging Face CLI configuration.
|
||||
max_tokens (`int`, *optional*, defaults to 1500):
|
||||
The maximum number of tokens allowed in the output.
|
||||
timeout (`int`, *optional*, defaults to 120):
|
||||
Timeout for the API request, in seconds.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If the model name is not provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
token: Optional[str] = None,
|
||||
max_tokens: Optional[int] = 1500,
|
||||
timeout: Optional[int] = 120,
|
||||
def generate(
|
||||
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
|
||||
):
|
||||
"""Initialize the HfApiEngine."""
|
||||
if not model:
|
||||
raise ValueError("Model name must be provided.")
|
||||
|
||||
self.model = model
|
||||
self.client = InferenceClient(self.model, token=token, timeout=timeout)
|
||||
self.max_tokens = max_tokens
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
stop_sequences: Optional[List[str]] = [],
|
||||
grammar: Optional[str] = None,
|
||||
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
|
||||
) -> str:
|
||||
"""Process the input messages and return the model's response.
|
||||
|
||||
@@ -136,6 +127,57 @@ class HfApiEngine:
|
||||
"Quantum mechanics is the branch of physics that studies..."
|
||||
```
|
||||
"""
|
||||
if not isinstance(messages, List):
|
||||
raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.")
|
||||
if stop_sequences is None:
|
||||
stop_sequences = []
|
||||
response = self.generate(messages, stop_sequences, grammar)
|
||||
self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True))
|
||||
self.last_output_token_count = len(self.tokenizer.encode(response))
|
||||
|
||||
# Remove stop sequences from LLM output
|
||||
for stop_seq in stop_sequences:
|
||||
if response[-len(stop_seq) :] == stop_seq:
|
||||
response = response[: -len(stop_seq)]
|
||||
return response
|
||||
|
||||
|
||||
class HfApiEngine(HfEngine):
|
||||
"""A class to interact with Hugging Face's Inference API for language model interaction.
|
||||
|
||||
This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
|
||||
|
||||
Parameters:
|
||||
model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
|
||||
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
|
||||
token (`str`, *optional*):
|
||||
Token used by the Hugging Face API for authentication.
|
||||
If not provided, the class will use the token stored in the Hugging Face CLI configuration.
|
||||
max_tokens (`int`, *optional*, defaults to 1500):
|
||||
The maximum number of tokens allowed in the output.
|
||||
timeout (`int`, *optional*, defaults to 120):
|
||||
Timeout for the API request, in seconds.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If the model name is not provided.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    token: Optional[str] = None,
    max_tokens: Optional[int] = 1500,
    timeout: Optional[int] = 120,
):
    """Create the inference client and store the generation settings.

    `model` is forwarded to the parent initializer as `model_id`, and is also the first argument
    given to `InferenceClient`, together with `token` and `timeout`.
    """
    super().__init__(model_id=model)
    self.max_tokens = max_tokens
    self.model = model
    self.client = InferenceClient(model, token=token, timeout=timeout)
|
||||
|
||||
def generate(
|
||||
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
|
||||
) -> str:
|
||||
# Get clean message list
|
||||
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
||||
|
||||
@@ -148,41 +190,40 @@ class HfApiEngine:
|
||||
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
# Remove stop sequences from LLM output
|
||||
for stop_seq in stop_sequences:
|
||||
if response[-len(stop_seq) :] == stop_seq:
|
||||
response = response[: -len(stop_seq)]
|
||||
return response
|
||||
|
||||
|
||||
class TransformersEngine:
|
||||
class TransformersEngine(HfEngine):
|
||||
"""This engine uses a pre-initialized local text-generation pipeline."""
|
||||
|
||||
def __init__(self, pipeline: Pipeline):
|
||||
def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None):
|
||||
super().__init__(model_id)
|
||||
self.pipeline = pipeline
|
||||
|
||||
def __call__(
|
||||
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
|
||||
def generate(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
grammar: Optional[str] = None,
|
||||
max_length: int = 1500,
|
||||
) -> str:
|
||||
# Get clean message list
|
||||
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
||||
|
||||
# Get LLM output
|
||||
if stop_sequences is not None and len(stop_sequences) > 0:
|
||||
stop_strings = stop_sequences
|
||||
else:
|
||||
stop_strings = None
|
||||
|
||||
output = self.pipeline(
|
||||
messages,
|
||||
stop_strings=stop_sequences,
|
||||
max_length=1500,
|
||||
stop_strings=stop_strings,
|
||||
max_length=max_length,
|
||||
tokenizer=self.pipeline.tokenizer,
|
||||
)
|
||||
|
||||
response = output[0]["generated_text"][-1]["content"]
|
||||
|
||||
# Remove stop sequences from LLM output
|
||||
if stop_sequences is not None:
|
||||
for stop_seq in stop_sequences:
|
||||
if response[-len(stop_seq) :] == stop_seq:
|
||||
response = response[: -len(stop_seq)]
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@@ -14,8 +14,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..utils import logging
|
||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||
from .agents import ReactAgent
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def pull_message(step_log: dict, test_mode: bool = True):
|
||||
@@ -54,7 +57,7 @@ def pull_message(step_log: dict, test_mode: bool = True):
|
||||
)
|
||||
|
||||
|
||||
def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kwargs):
|
||||
def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs):
|
||||
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
||||
|
||||
try:
|
||||
@@ -91,3 +94,24 @@ def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kw
|
||||
)
|
||||
else:
|
||||
yield ChatMessage(role="assistant", content=str(final_answer))
|
||||
|
||||
|
||||
class Monitor:
    """Accumulate step durations and token usage reported by an LLM engine.

    `update_metrics` takes a step log dictionary, records its `"step_duration"`, and adds the
    token counts currently exposed by the tracked engine to running totals.

    Args:
        tracked_llm_engine:
            Object whose `last_input_token_count` and `last_output_token_count` attributes are
            read on every update. Engines that do not expose them (or expose `None`) are
            accepted; only step durations are recorded for them.

    Attributes are plain instance fields:
        step_durations: list of the `"step_duration"` values seen so far, in order.
        total_input_token_count / total_output_token_count: running sums, starting at 0.
    """

    def __init__(self, tracked_llm_engine):
        self.step_durations = []
        self.tracked_llm_engine = tracked_llm_engine
        # Always create the totals: an engine may only start exposing token counts after the
        # monitor was built, and `update_metrics` must then be able to add to them.
        self.total_input_token_count = 0
        self.total_output_token_count = 0

    def update_metrics(self, step_log):
        """Record one step's duration and add the engine's latest token counts to the totals.

        Args:
            step_log (`dict`): Step log; must contain a numeric `"step_duration"` entry.

        Raises:
            KeyError: If `step_log` has no `"step_duration"` key.
        """
        step_duration = step_log["step_duration"]
        self.step_durations.append(step_duration)
        logger.info(f"Step {len(self.step_durations)}:")
        logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")

        input_count = getattr(self.tracked_llm_engine, "last_input_token_count", None)
        if input_count is None:
            # The engine does not report token usage (yet): nothing to accumulate or log.
            return

        self.total_input_token_count += input_count
        # The output count is read defensively: an engine may report only the input side.
        output_count = getattr(self.tracked_llm_engine, "last_output_token_count", None)
        if output_count is not None:
            self.total_output_token_count += output_count
        logger.info(f"- Input tokens: {self.total_input_token_count}")
        logger.info(f"- Output tokens: {self.total_output_token_count}")
|
||||
|
||||
@@ -785,21 +785,22 @@ def launch_gradio_demo(tool_class: Tool):
|
||||
def fn(*args, **kwargs):
|
||||
return tool(*args, **kwargs)
|
||||
|
||||
TYPE_TO_COMPONENT_CLASS_MAPPING = {
|
||||
"image": gr.Image,
|
||||
"audio": gr.Audio,
|
||||
"string": gr.Textbox,
|
||||
"integer": gr.Textbox,
|
||||
"number": gr.Textbox,
|
||||
}
|
||||
|
||||
gradio_inputs = []
|
||||
for input_name, input_details in tool_class.inputs.items():
|
||||
input_type = input_details["type"]
|
||||
if input_type == "image":
|
||||
gradio_inputs.append(gr.Image(label=input_name))
|
||||
elif input_type == "audio":
|
||||
gradio_inputs.append(gr.Audio(label=input_name))
|
||||
elif input_type in ["string", "integer", "number"]:
|
||||
gradio_inputs.append(gr.Textbox(label=input_name))
|
||||
else:
|
||||
error_message = f"Input type '{input_type}' not supported."
|
||||
raise ValueError(error_message)
|
||||
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
|
||||
new_component = input_gradio_component_class(label=input_name)
|
||||
gradio_inputs.append(new_component)
|
||||
|
||||
gradio_output = tool_class.output_type
|
||||
assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported."
|
||||
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
|
||||
gradio_output = output_gradio_componentclass(label=input_name)
|
||||
|
||||
gr.Interface(
|
||||
fn=fn,
|
||||
|
||||
@@ -21,11 +21,95 @@ from transformers.agents.monitoring import stream_to_gradio
|
||||
|
||||
|
||||
class MonitoringTester(unittest.TestCase):
|
||||
def test_code_agent_metrics(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return """
|
||||
Code:
|
||||
```py
|
||||
final_answer('This is the final answer.')
|
||||
```"""
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
||||
|
||||
def test_json_agent_metrics(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
||||
|
||||
agent = ReactJsonAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
||||
|
||||
def test_code_agent_metrics_max_iterations(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return "Malformed answer"
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
||||
|
||||
def test_code_agent_metrics_generation_error(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
raise AgentError
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
||||
|
||||
def test_streaming_agent_text_output(self):
|
||||
def dummy_llm_engine(prompt, **kwargs):
|
||||
return """
|
||||
Code:
|
||||
````
|
||||
```py
|
||||
final_answer('This is the final answer.')
|
||||
```"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user