Agents: turn any Space into a Tool with Tool.from_space() (#34561)
* Agents: you can now load a Space as a tool
This commit is contained in:
@@ -464,7 +464,7 @@ image = image_generator(prompt=improved_prompt)
|
||||
|
||||
قبل إنشاء الصورة أخيرًا:
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" />
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit_spacesuit_flux.webp" />
|
||||
|
||||
> [!WARNING]
|
||||
> تتطلب gradio-tools إدخالات وإخراجات *نصية* حتى عند العمل مع طرائق مختلفة مثل كائنات الصور والصوت. الإدخالات والإخراجات الصورية والصوتية غير متوافقة حاليًا.
|
||||
|
||||
@@ -123,6 +123,54 @@ from transformers import load_tool, CodeAgent
|
||||
model_download_tool = load_tool("m-ric/hf-model-downloads")
|
||||
```
|
||||
|
||||
### Import a Space as a tool 🚀
|
||||
|
||||
You can directly import a Space from the Hub as a tool using the [`Tool.from_space`] method!
|
||||
|
||||
You only need to provide the id of the Space on the Hub, its name, and a description that will help you agent understand what the tool does. Under the hood, this will use [`gradio-client`](https://pypi.org/project/gradio-client/) library to call the Space.
|
||||
|
||||
For instance, let's import the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) Space from the Hub and use it to generate an image.
|
||||
|
||||
```
|
||||
from transformers import Tool
|
||||
|
||||
image_generation_tool = Tool.from_space(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
name="image_generator",
|
||||
description="Generate an image from a prompt")
|
||||
|
||||
image_generation_tool("A sunny beach")
|
||||
```
|
||||
And voilà, here's your image! 🏖️
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sunny_beach.webp">
|
||||
|
||||
Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it.
|
||||
|
||||
```python
|
||||
from transformers import ReactCodeAgent
|
||||
|
||||
agent = ReactCodeAgent(tools=[image_generation_tool])
|
||||
|
||||
agent.run(
|
||||
"Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
|
||||
)
|
||||
```
|
||||
|
||||
```text
|
||||
=== Agent thoughts:
|
||||
improved_prompt could be "A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background"
|
||||
|
||||
Now that I have improved the prompt, I can use the image generator tool to generate an image based on this prompt.
|
||||
>>> Agent is executing the code below:
|
||||
image = image_generator(prompt="A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background")
|
||||
final_answer(image)
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit_spacesuit_flux.webp">
|
||||
|
||||
How cool is this? 🤩
|
||||
|
||||
### Use gradio-tools
|
||||
|
||||
[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
|
||||
@@ -140,36 +188,6 @@ gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
|
||||
prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
|
||||
```
|
||||
|
||||
Now you can use it just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit`.
|
||||
|
||||
```python
|
||||
image_generation_tool = load_tool('huggingface-tools/text-to-image')
|
||||
agent = CodeAgent(tools=[prompt_generator_tool, image_generation_tool], llm_engine=llm_engine)
|
||||
|
||||
agent.run(
|
||||
"Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
|
||||
)
|
||||
```
|
||||
|
||||
The model adequately leverages the tool:
|
||||
```text
|
||||
======== New task ========
|
||||
Improve this prompt, then generate an image of it.
|
||||
You have been provided with these initial arguments: {'prompt': 'A rabbit wearing a space suit'}.
|
||||
==== Agent is executing the code below:
|
||||
improved_prompt = StableDiffusionPromptGenerator(query=prompt)
|
||||
while improved_prompt == "QUEUE_FULL":
|
||||
improved_prompt = StableDiffusionPromptGenerator(query=prompt)
|
||||
print(f"The improved prompt is {improved_prompt}.")
|
||||
image = image_generator(prompt=improved_prompt)
|
||||
====
|
||||
```
|
||||
|
||||
Before finally generating the image:
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">
|
||||
|
||||
|
||||
> [!WARNING]
|
||||
> gradio-tools require *textual* inputs and outputs even when working with different modalities like image and audio objects. Image and audio inputs and outputs are currently incompatible.
|
||||
|
||||
@@ -179,7 +197,7 @@ We love Langchain and think it has a very compelling suite of tools.
|
||||
To import a tool from LangChain, use the `from_langchain()` method.
|
||||
|
||||
Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
|
||||
|
||||
This tool will need `pip install google-search-results` to work properly.
|
||||
```python
|
||||
from langchain.agents import load_tools
|
||||
from transformers import Tool, ReactCodeAgent
|
||||
@@ -188,7 +206,7 @@ search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
|
||||
|
||||
agent = ReactCodeAgent(tools=[search_tool])
|
||||
|
||||
agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
|
||||
agent.run("How many more blocks (also denoted as layers) are in BERT base encoder compared to the encoder from the architecture proposed in Attention is All You Need?")
|
||||
```
|
||||
|
||||
## Display your agent run in a cool Gradio interface
|
||||
|
||||
@@ -87,20 +87,22 @@ launch_gradio_demo({class_name})
|
||||
"""
|
||||
|
||||
|
||||
def validate_after_init(cls):
|
||||
def validate_after_init(cls, do_validate_forward: bool = True):
|
||||
original_init = cls.__init__
|
||||
|
||||
@wraps(original_init)
|
||||
def new_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
if not isinstance(self, PipelineTool):
|
||||
self.validate_arguments()
|
||||
self.validate_arguments(do_validate_forward=do_validate_forward)
|
||||
|
||||
cls.__init__ = new_init
|
||||
return cls
|
||||
|
||||
|
||||
@validate_after_init
|
||||
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
||||
|
||||
|
||||
class Tool:
|
||||
"""
|
||||
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
|
||||
@@ -131,7 +133,11 @@ class Tool:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.is_initialized = False
|
||||
|
||||
def validate_arguments(self):
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
validate_after_init(cls, do_validate_forward=False)
|
||||
|
||||
def validate_arguments(self, do_validate_forward: bool = True):
|
||||
required_attributes = {
|
||||
"description": str,
|
||||
"name": str,
|
||||
@@ -145,15 +151,17 @@ class Tool:
|
||||
if not isinstance(attr_value, expected_type):
|
||||
raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
|
||||
for input_name, input_content in self.inputs.items():
|
||||
assert "type" in input_content, f"Input '{input_name}' should specify a type."
|
||||
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
|
||||
assert (
|
||||
"type" in input_content and "description" in input_content
|
||||
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
||||
if input_content["type"] not in authorized_types:
|
||||
raise Exception(
|
||||
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
|
||||
)
|
||||
assert "description" in input_content, f"Input '{input_name}' should have a description."
|
||||
|
||||
assert getattr(self, "output_type", None) in authorized_types
|
||||
|
||||
if do_validate_forward:
|
||||
if not isinstance(self, PipelineTool):
|
||||
signature = inspect.signature(self.forward)
|
||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||
@@ -405,6 +413,58 @@ class Tool:
|
||||
repo_type="space",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_space(space_id, name, description):
|
||||
"""
|
||||
Creates a [`Tool`] from a Space given its id on the Hub.
|
||||
|
||||
Args:
|
||||
space_id (`str`):
|
||||
The id of the Space on the Hub.
|
||||
name (`str`):
|
||||
The name of the tool.
|
||||
description (`str`):
|
||||
The description of the tool.
|
||||
|
||||
Returns:
|
||||
[`Tool`]:
|
||||
The created tool.
|
||||
|
||||
Example:
|
||||
```
|
||||
tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image-generator", "Generate an image from a prompt")
|
||||
```
|
||||
"""
|
||||
from gradio_client import Client
|
||||
|
||||
class SpaceToolWrapper(Tool):
|
||||
def __init__(self, space_id, name, description):
|
||||
self.client = Client(space_id)
|
||||
self.name = name
|
||||
self.description = description
|
||||
space_description = self.client.view_api(return_format="dict")["named_endpoints"]
|
||||
route = list(space_description.keys())[0]
|
||||
space_description_route = space_description[route]
|
||||
self.inputs = {}
|
||||
for parameter in space_description_route["parameters"]:
|
||||
if not parameter["parameter_has_default"]:
|
||||
self.inputs[parameter["parameter_name"]] = {
|
||||
"type": parameter["type"]["type"],
|
||||
"description": parameter["python_type"]["description"],
|
||||
}
|
||||
output_component = space_description_route["returns"][0]["component"]
|
||||
if output_component == "Image":
|
||||
self.output_type = "image"
|
||||
elif output_component == "Audio":
|
||||
self.output_type = "audio"
|
||||
else:
|
||||
self.output_type = "any"
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result
|
||||
|
||||
return SpaceToolWrapper(space_id, name, description)
|
||||
|
||||
@staticmethod
|
||||
def from_gradio(gradio_tool):
|
||||
"""
|
||||
@@ -414,16 +474,15 @@ class Tool:
|
||||
|
||||
class GradioToolWrapper(Tool):
|
||||
def __init__(self, _gradio_tool):
|
||||
super().__init__()
|
||||
self.name = _gradio_tool.name
|
||||
self.description = _gradio_tool.description
|
||||
self.output_type = "string"
|
||||
self._gradio_tool = _gradio_tool
|
||||
func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
|
||||
self.inputs = {key: "" for key in func_args}
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self._gradio_tool.run(*args, **kwargs)
|
||||
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
|
||||
self.inputs = {
|
||||
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
|
||||
}
|
||||
self.forward = self._gradio_tool.run
|
||||
|
||||
return GradioToolWrapper(gradio_tool)
|
||||
|
||||
@@ -435,10 +494,13 @@ class Tool:
|
||||
|
||||
class LangChainToolWrapper(Tool):
|
||||
def __init__(self, _langchain_tool):
|
||||
super().__init__()
|
||||
self.name = _langchain_tool.name.lower()
|
||||
self.description = _langchain_tool.description
|
||||
self.inputs = parse_langchain_args(_langchain_tool.args)
|
||||
self.inputs = _langchain_tool.args.copy()
|
||||
for input_content in self.inputs.values():
|
||||
if "title" in input_content:
|
||||
input_content.pop("title")
|
||||
input_content["description"] = ""
|
||||
self.output_type = "string"
|
||||
self.langchain_tool = _langchain_tool
|
||||
|
||||
@@ -805,15 +867,6 @@ class EndpointClient:
|
||||
return response.json()
|
||||
|
||||
|
||||
def parse_langchain_args(args: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Parse the args attribute of a LangChain tool to create a matching inputs dictionary."""
|
||||
inputs = args.copy()
|
||||
for arg_details in inputs.values():
|
||||
if "title" in arg_details:
|
||||
arg_details.pop("title")
|
||||
return inputs
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""
|
||||
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
|
||||
|
||||
Reference in New Issue
Block a user