Agents: turn any Space into a Tool with Tool.from_space() (#34561)
* Agents: you can now load a Space as a tool
This commit is contained in:
@@ -464,7 +464,7 @@ image = image_generator(prompt=improved_prompt)
|
|||||||
|
|
||||||
قبل إنشاء الصورة أخيرًا:
|
قبل إنشاء الصورة أخيرًا:
|
||||||
|
|
||||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" />
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit_spacesuit_flux.webp" />
|
||||||
|
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
> تتطلب gradio-tools إدخالات وإخراجات *نصية* حتى عند العمل مع طرائق مختلفة مثل كائنات الصور والصوت. الإدخالات والإخراجات الصورية والصوتية غير متوافقة حاليًا.
|
> تتطلب gradio-tools إدخالات وإخراجات *نصية* حتى عند العمل مع طرائق مختلفة مثل كائنات الصور والصوت. الإدخالات والإخراجات الصورية والصوتية غير متوافقة حاليًا.
|
||||||
|
|||||||
@@ -123,6 +123,54 @@ from transformers import load_tool, CodeAgent
|
|||||||
model_download_tool = load_tool("m-ric/hf-model-downloads")
|
model_download_tool = load_tool("m-ric/hf-model-downloads")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Import a Space as a tool 🚀
|
||||||
|
|
||||||
|
You can directly import a Space from the Hub as a tool using the [`Tool.from_space`] method!
|
||||||
|
|
||||||
|
You only need to provide the id of the Space on the Hub, its name, and a description that will help your agent understand what the tool does. Under the hood, this will use the [`gradio-client`](https://pypi.org/project/gradio-client/) library to call the Space.
|
||||||
|
|
||||||
|
For instance, let's import the [FLUX.1-dev](https://huggingface.co/spaces/black-forest-labs/FLUX.1-dev) Space from the Hub and use it to generate an image.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import Tool
|
||||||
|
|
||||||
|
image_generation_tool = Tool.from_space(
|
||||||
|
"black-forest-labs/FLUX.1-dev",
|
||||||
|
name="image_generator",
|
||||||
|
description="Generate an image from a prompt")
|
||||||
|
|
||||||
|
image_generation_tool("A sunny beach")
|
||||||
|
```
|
||||||
|
And voilà, here's your image! 🏖️
|
||||||
|
|
||||||
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sunny_beach.webp">
|
||||||
|
|
||||||
|
Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import ReactCodeAgent
|
||||||
|
|
||||||
|
agent = ReactCodeAgent(tools=[image_generation_tool])
|
||||||
|
|
||||||
|
agent.run(
|
||||||
|
"Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```text
|
||||||
|
=== Agent thoughts:
|
||||||
|
improved_prompt could be "A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background"
|
||||||
|
|
||||||
|
Now that I have improved the prompt, I can use the image generator tool to generate an image based on this prompt.
|
||||||
|
>>> Agent is executing the code below:
|
||||||
|
image = image_generator(prompt="A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background")
|
||||||
|
final_answer(image)
|
||||||
|
```
|
||||||
|
|
||||||
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit_spacesuit_flux.webp">
|
||||||
|
|
||||||
|
How cool is this? 🤩
|
||||||
|
|
||||||
### Use gradio-tools
|
### Use gradio-tools
|
||||||
|
|
||||||
[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
|
[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
|
||||||
@@ -140,36 +188,6 @@ gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
|
|||||||
prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
|
prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
|
||||||
```
|
```
|
||||||
|
|
||||||
Now you can use it just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
image_generation_tool = load_tool('huggingface-tools/text-to-image')
|
|
||||||
agent = CodeAgent(tools=[prompt_generator_tool, image_generation_tool], llm_engine=llm_engine)
|
|
||||||
|
|
||||||
agent.run(
|
|
||||||
"Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
The model adequately leverages the tool:
|
|
||||||
```text
|
|
||||||
======== New task ========
|
|
||||||
Improve this prompt, then generate an image of it.
|
|
||||||
You have been provided with these initial arguments: {'prompt': 'A rabbit wearing a space suit'}.
|
|
||||||
==== Agent is executing the code below:
|
|
||||||
improved_prompt = StableDiffusionPromptGenerator(query=prompt)
|
|
||||||
while improved_prompt == "QUEUE_FULL":
|
|
||||||
improved_prompt = StableDiffusionPromptGenerator(query=prompt)
|
|
||||||
print(f"The improved prompt is {improved_prompt}.")
|
|
||||||
image = image_generator(prompt=improved_prompt)
|
|
||||||
====
|
|
||||||
```
|
|
||||||
|
|
||||||
Before finally generating the image:
|
|
||||||
|
|
||||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">
|
|
||||||
|
|
||||||
|
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
> gradio-tools require *textual* inputs and outputs even when working with different modalities like image and audio objects. Image and audio inputs and outputs are currently incompatible.
|
> gradio-tools require *textual* inputs and outputs even when working with different modalities like image and audio objects. Image and audio inputs and outputs are currently incompatible.
|
||||||
|
|
||||||
@@ -179,7 +197,7 @@ We love Langchain and think it has a very compelling suite of tools.
|
|||||||
To import a tool from LangChain, use the `from_langchain()` method.
|
To import a tool from LangChain, use the `from_langchain()` method.
|
||||||
|
|
||||||
Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
|
Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
|
||||||
|
This tool will need `pip install google-search-results` to work properly.
|
||||||
```python
|
```python
|
||||||
from langchain.agents import load_tools
|
from langchain.agents import load_tools
|
||||||
from transformers import Tool, ReactCodeAgent
|
from transformers import Tool, ReactCodeAgent
|
||||||
@@ -188,7 +206,7 @@ search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
|
|||||||
|
|
||||||
agent = ReactCodeAgent(tools=[search_tool])
|
agent = ReactCodeAgent(tools=[search_tool])
|
||||||
|
|
||||||
agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
|
agent.run("How many more blocks (also denoted as layers) are in BERT base encoder compared to the encoder from the architecture proposed in Attention is All You Need?")
|
||||||
```
|
```
|
||||||
|
|
||||||
## Display your agent run in a cool Gradio interface
|
## Display your agent run in a cool Gradio interface
|
||||||
|
|||||||
@@ -87,20 +87,22 @@ launch_gradio_demo({class_name})
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def validate_after_init(cls):
|
def validate_after_init(cls, do_validate_forward: bool = True):
|
||||||
original_init = cls.__init__
|
original_init = cls.__init__
|
||||||
|
|
||||||
@wraps(original_init)
|
@wraps(original_init)
|
||||||
def new_init(self, *args, **kwargs):
|
def new_init(self, *args, **kwargs):
|
||||||
original_init(self, *args, **kwargs)
|
original_init(self, *args, **kwargs)
|
||||||
if not isinstance(self, PipelineTool):
|
if not isinstance(self, PipelineTool):
|
||||||
self.validate_arguments()
|
self.validate_arguments(do_validate_forward=do_validate_forward)
|
||||||
|
|
||||||
cls.__init__ = new_init
|
cls.__init__ = new_init
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
@validate_after_init
|
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
||||||
|
|
||||||
|
|
||||||
class Tool:
|
class Tool:
|
||||||
"""
|
"""
|
||||||
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
|
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
|
||||||
@@ -131,7 +133,11 @@ class Tool:
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.is_initialized = False
|
self.is_initialized = False
|
||||||
|
|
||||||
def validate_arguments(self):
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
validate_after_init(cls, do_validate_forward=False)
|
||||||
|
|
||||||
|
def validate_arguments(self, do_validate_forward: bool = True):
|
||||||
required_attributes = {
|
required_attributes = {
|
||||||
"description": str,
|
"description": str,
|
||||||
"name": str,
|
"name": str,
|
||||||
@@ -145,21 +151,23 @@ class Tool:
|
|||||||
if not isinstance(attr_value, expected_type):
|
if not isinstance(attr_value, expected_type):
|
||||||
raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
|
raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
|
||||||
for input_name, input_content in self.inputs.items():
|
for input_name, input_content in self.inputs.items():
|
||||||
assert "type" in input_content, f"Input '{input_name}' should specify a type."
|
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
|
||||||
|
assert (
|
||||||
|
"type" in input_content and "description" in input_content
|
||||||
|
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
||||||
if input_content["type"] not in authorized_types:
|
if input_content["type"] not in authorized_types:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
|
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
|
||||||
)
|
)
|
||||||
assert "description" in input_content, f"Input '{input_name}' should have a description."
|
|
||||||
|
|
||||||
assert getattr(self, "output_type", None) in authorized_types
|
assert getattr(self, "output_type", None) in authorized_types
|
||||||
|
if do_validate_forward:
|
||||||
if not isinstance(self, PipelineTool):
|
if not isinstance(self, PipelineTool):
|
||||||
signature = inspect.signature(self.forward)
|
signature = inspect.signature(self.forward)
|
||||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return NotImplemented("Write this method in your subclass of `Tool`.")
|
return NotImplemented("Write this method in your subclass of `Tool`.")
|
||||||
@@ -405,6 +413,58 @@ class Tool:
|
|||||||
repo_type="space",
|
repo_type="space",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_space(space_id, name, description):
|
||||||
|
"""
|
||||||
|
Creates a [`Tool`] from a Space given its id on the Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space_id (`str`):
|
||||||
|
The id of the Space on the Hub.
|
||||||
|
name (`str`):
|
||||||
|
The name of the tool.
|
||||||
|
description (`str`):
|
||||||
|
The description of the tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`Tool`]:
|
||||||
|
The created tool.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image-generator", "Generate an image from a prompt")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from gradio_client import Client
|
||||||
|
|
||||||
|
class SpaceToolWrapper(Tool):
|
||||||
|
def __init__(self, space_id, name, description):
|
||||||
|
self.client = Client(space_id)
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
space_description = self.client.view_api(return_format="dict")["named_endpoints"]
|
||||||
|
route = list(space_description.keys())[0]
|
||||||
|
space_description_route = space_description[route]
|
||||||
|
self.inputs = {}
|
||||||
|
for parameter in space_description_route["parameters"]:
|
||||||
|
if not parameter["parameter_has_default"]:
|
||||||
|
self.inputs[parameter["parameter_name"]] = {
|
||||||
|
"type": parameter["type"]["type"],
|
||||||
|
"description": parameter["python_type"]["description"],
|
||||||
|
}
|
||||||
|
output_component = space_description_route["returns"][0]["component"]
|
||||||
|
if output_component == "Image":
|
||||||
|
self.output_type = "image"
|
||||||
|
elif output_component == "Audio":
|
||||||
|
self.output_type = "audio"
|
||||||
|
else:
|
||||||
|
self.output_type = "any"
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result
|
||||||
|
|
||||||
|
return SpaceToolWrapper(space_id, name, description)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_gradio(gradio_tool):
|
def from_gradio(gradio_tool):
|
||||||
"""
|
"""
|
||||||
@@ -414,16 +474,15 @@ class Tool:
|
|||||||
|
|
||||||
class GradioToolWrapper(Tool):
|
class GradioToolWrapper(Tool):
|
||||||
def __init__(self, _gradio_tool):
|
def __init__(self, _gradio_tool):
|
||||||
super().__init__()
|
|
||||||
self.name = _gradio_tool.name
|
self.name = _gradio_tool.name
|
||||||
self.description = _gradio_tool.description
|
self.description = _gradio_tool.description
|
||||||
self.output_type = "string"
|
self.output_type = "string"
|
||||||
self._gradio_tool = _gradio_tool
|
self._gradio_tool = _gradio_tool
|
||||||
func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
|
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
|
||||||
self.inputs = {key: "" for key in func_args}
|
self.inputs = {
|
||||||
|
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
|
||||||
def forward(self, *args, **kwargs):
|
}
|
||||||
return self._gradio_tool.run(*args, **kwargs)
|
self.forward = self._gradio_tool.run
|
||||||
|
|
||||||
return GradioToolWrapper(gradio_tool)
|
return GradioToolWrapper(gradio_tool)
|
||||||
|
|
||||||
@@ -435,10 +494,13 @@ class Tool:
|
|||||||
|
|
||||||
class LangChainToolWrapper(Tool):
|
class LangChainToolWrapper(Tool):
|
||||||
def __init__(self, _langchain_tool):
|
def __init__(self, _langchain_tool):
|
||||||
super().__init__()
|
|
||||||
self.name = _langchain_tool.name.lower()
|
self.name = _langchain_tool.name.lower()
|
||||||
self.description = _langchain_tool.description
|
self.description = _langchain_tool.description
|
||||||
self.inputs = parse_langchain_args(_langchain_tool.args)
|
self.inputs = _langchain_tool.args.copy()
|
||||||
|
for input_content in self.inputs.values():
|
||||||
|
if "title" in input_content:
|
||||||
|
input_content.pop("title")
|
||||||
|
input_content["description"] = ""
|
||||||
self.output_type = "string"
|
self.output_type = "string"
|
||||||
self.langchain_tool = _langchain_tool
|
self.langchain_tool = _langchain_tool
|
||||||
|
|
||||||
@@ -805,15 +867,6 @@ class EndpointClient:
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
def parse_langchain_args(args: Dict[str, str]) -> Dict[str, str]:
|
|
||||||
"""Parse the args attribute of a LangChain tool to create a matching inputs dictionary."""
|
|
||||||
inputs = args.copy()
|
|
||||||
for arg_details in inputs.values():
|
|
||||||
if "title" in arg_details:
|
|
||||||
arg_details.pop("title")
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCollection:
|
class ToolCollection:
|
||||||
"""
|
"""
|
||||||
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
|
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
|
||||||
|
|||||||
Reference in New Issue
Block a user