Add token cost + runtime monitoring to Agent and HfEngine children (#34548)
* Add monitoring to Agent and HfEngine children
This commit is contained in:
@@ -21,11 +21,95 @@ from transformers.agents.monitoring import stream_to_gradio
|
||||
|
||||
|
||||
class MonitoringTester(unittest.TestCase):
|
||||
def test_code_agent_metrics(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return """
|
||||
Code:
|
||||
```py
|
||||
final_answer('This is the final answer.')
|
||||
```"""
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
||||
|
||||
def test_json_agent_metrics(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
||||
|
||||
agent = ReactJsonAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
||||
|
||||
def test_code_agent_metrics_max_iterations(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return "Malformed answer"
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
||||
|
||||
def test_code_agent_metrics_generation_error(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
raise AgentError
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
||||
|
||||
def test_streaming_agent_text_output(self):
|
||||
def dummy_llm_engine(prompt, **kwargs):
|
||||
return """
|
||||
Code:
|
||||
````
|
||||
```py
|
||||
final_answer('This is the final answer.')
|
||||
```"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user