danielkorat commited on
Commit
28823d7
·
verified ·
1 Parent(s): 2d814fa

Update tool.py

Browse files
Files changed (1) hide show
  1. tool.py +6 -5
tool.py CHANGED
@@ -3,20 +3,21 @@ import os
3
  from smolagents import Tool
4
  from huggingface_hub import InferenceClient
5
 
 
6
 
7
  class TextToImageTool(Tool):
8
- description = "This tool creates an image according to a prompt, which is a text description. When using this tool in a code snippet, the last line of the snippet should be the the tool call (so the image is shown on screen)."
 
9
  name = "image_generator"
10
  inputs = {"prompt": {"type": "string", "description": "The image generator prompt. Don't hesitate to add details in the prompt to make the image look better, like 'high-res, photorealistic', etc."},
11
- "save_path": {"type": "string", "description": "A file path in `/tmp` to save the image to. The file path extenstion should be .png", "nullable": True}
12
  }
13
  output_type = "image"
14
  model_sdxl = "black-forest-labs/FLUX.1-schnell"
15
  client = InferenceClient(model_sdxl, token=os.environ["HUB_TOKEN"])
16
 
17
 
18
- def forward(self, prompt, save_path=None):
19
  image = self.client.text_to_image(prompt)
20
- if save_path is not None:
21
- image.save(save_path)
22
  return image
 
3
  from smolagents import Tool
4
  from huggingface_hub import InferenceClient
5
 
6
+ SAVE_PATH = "/tmp/generated_image.png"
7
 
8
  class TextToImageTool(Tool):
9
+ # description = "This tool creates an image according to a prompt, which is a text description. When using this tool in a code snippet, the last line of the snippet should be the the tool call (so the image is shown on screen)."
10
+ description = f"This tool creates an image according to a prompt, which is a text description. The generated image is always saved to `{SAVE_PATH}`."
11
  name = "image_generator"
12
  inputs = {"prompt": {"type": "string", "description": "The image generator prompt. Don't hesitate to add details in the prompt to make the image look better, like 'high-res, photorealistic', etc."},
13
+ # "save_path": {"type": "string", "description": "A file path in `/tmp` to save the image to. The file path extenstion should be .png", "nullable": True}
14
  }
15
  output_type = "image"
16
  model_sdxl = "black-forest-labs/FLUX.1-schnell"
17
  client = InferenceClient(model_sdxl, token=os.environ["HUB_TOKEN"])
18
 
19
 
20
+ def forward(self, prompt):
21
  image = self.client.text_to_image(prompt)
22
+ image.save(SAVE_PATH)
 
23
  return image