File size: 963 Bytes
51cbe83
 
8ade5a8
f428b3b
 
 
 
27c6979
f428b3b
328babc
e4b86f4
328babc
f428b3b
27c6979
4c36c35
f428b3b
 
6dede21
328babc
6dede21
 
328babc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os

from smolagents import Tool
from huggingface_hub import InferenceClient


class TextToImageTool(Tool):
    """smolagents tool that generates an image from a text prompt.

    The prompt is sent to the Hugging Face Inference API through a single
    ``InferenceClient`` that is created once, when the class body executes,
    and shared by every instance of the tool.
    """

    description = "This tool creates an image according to a prompt, which is a text description."
    name = "image_generator"
    inputs = {
        "prompt": {
            "type": "string",
            "description": "The image generator prompt. Don't hesitate to add details in the prompt to make the image look better, like 'high-res, photorealistic', etc.",
        },
        "save_path": {
            "type": "string",
            "description": "A file path in `/tmp` to save the image to. The file path extenstion should be .png",
            "nullable": True,
        },
    }
    output_type = "image"
    # The attribute name is kept for backward compatibility; the model it
    # points at is FLUX.1-schnell, not an SDXL checkpoint.
    model_sdxl = "black-forest-labs/FLUX.1-schnell"
    # Use `.get` so that merely importing this module does not raise KeyError
    # when HUB_TOKEN is unset. With `token=None` the client falls back to the
    # locally saved Hugging Face token (see the InferenceClient documentation).
    client = InferenceClient(model_sdxl, token=os.environ.get("HUB_TOKEN"))

    def forward(self, prompt, save_path=None):
        """Generate an image for ``prompt`` and optionally save it to disk.

        Args:
            prompt: Text description of the image to generate.
            save_path: Optional file path to write the image to. ``None`` or
                an empty string means the image is not saved. Missing parent
                directories are created.

        Returns:
            The image object returned by ``InferenceClient.text_to_image``.
        """
        image = self.client.text_to_image(prompt)
        if save_path:
            # A path such as /tmp/run1/out.png would otherwise fail if the
            # intermediate directory has not been created yet.
            parent_dir = os.path.dirname(save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            image.save(save_path)
        return image