Spaces:
Runtime error
Runtime error
| import requests | |
| import json | |
| from ..tool import Tool | |
| import os | |
| from steamship import Block, Steamship | |
| import uuid | |
| from enum import Enum | |
| import re | |
| from IPython import display | |
| from IPython.display import Image | |
class ModelName(str, Enum):
    """Image-generation models this tool can call.

    Members are ``str`` instances; a member's value is what gets passed as the
    Steamship ``plugin_handle`` when the image generator plugin is loaded.
    """

    DALL_E = "dall-e"
    STABLE_DIFFUSION = "stable-diffusion"


# Image sizes ("WIDTHxHEIGHT" strings) listed as supported, keyed by model.
SUPPORTED_IMAGE_SIZES = {
    ModelName.DALL_E: ("256x256", "512x512", "1024x1024"),
    ModelName.STABLE_DIFFUSION: ("512x512", "768x768"),
}
def make_image_public(client: Steamship, block: Block) -> str:
    """Upload a block's raw bytes to workspace storage and return a read URL.

    A fresh UUID is used as the storage path. A WRITE-signed URL for that path
    receives ``block.raw()``, and a READ-signed URL for the same path is
    returned so the content can be fetched by URL.

    Args:
        client: Steamship client whose workspace issues the signed URLs.
        block: Block whose raw content is uploaded.

    Returns:
        The READ-signed URL for the uploaded content.

    Raises:
        ValueError: If the steamship signed-URL helpers cannot be imported
            (the original ImportError is chained as ``__cause__``).
    """
    try:
        from steamship.data.workspace import SignedUrl
        from steamship.utils.signed_urls import upload_to_signed_url
    except ImportError as err:
        raise ValueError(
            "The make_image_public function requires the steamship"
            " package to be installed. Please install steamship"
            " with `pip install --upgrade steamship`"
        ) from err

    filepath = str(uuid.uuid4())
    # A single workspace lookup serves both signed-URL requests.
    workspace = client.get_workspace()

    def _signed_url(operation) -> str:
        # Sign the same PLUGIN_DATA path for the requested operation.
        request = SignedUrl.Request(
            bucket=SignedUrl.Bucket.PLUGIN_DATA,
            filepath=filepath,
            operation=operation,
        )
        return workspace.create_signed_url(request).signed_url

    write_url = _signed_url(SignedUrl.Operation.WRITE)
    read_url = _signed_url(SignedUrl.Operation.READ)
    upload_to_signed_url(write_url, block.raw())
    return read_url
def show_output(output):
    """Display the multi-modal output from the agent.

    ``output`` is split around UUID-shaped substrings. Each such substring is
    treated as a Steamship block id: the block is fetched and its raw bytes
    are shown with IPython's ``display(Image(...))``. Every other non-empty
    text segment is printed, followed by a blank line.

    Args:
        output: Agent output text, possibly containing block ids.
    """
    uuid_pattern = re.compile(
        r"([0-9A-Za-z]{8}-[0-9A-Za-z]{4}-[0-9A-Za-z]{4}-[0-9A-Za-z]{4}-[0-9A-Za-z]{12})"
    )
    # Created on the first image only, and reused for later ones.
    client = None
    # The capturing group makes split() keep the matched ids as segments.
    for segment in uuid_pattern.split(output):
        # Strip leading non-word characters (e.g. ": " left before an id).
        segment = re.sub(r"^\W+", "", segment)
        if not segment:
            # Skip empty pieces the split yields at the edges or between ids.
            continue
        block_id = uuid_pattern.search(segment)
        if block_id is None:
            print(segment, end="\n\n")
            continue
        # The file-level `from IPython import display` binds the module, which
        # is not callable; import the display function itself. Imported here
        # so text-only output does not require IPython.
        from IPython.display import Image, display

        if client is None:
            client = Steamship()
        display(Image(Block.get(client, _id=block_id.group()).raw()))
def build_tool(config) -> Tool:
    """Build the "Image Generator" tool backed by a Steamship image plugin.

    Args:
        config: Accepted for interface compatibility; not read here.

    Returns:
        The ``Tool`` with a ``generate_image`` handler registered on it.

    Raises:
        ValueError: If the configured size is not listed in
            ``SUPPORTED_IMAGE_SIZES`` for the configured model.
        RuntimeError: If ``STEAMSHIP_API_KEY`` is unset or empty.
    """
    tool = Tool(
        "Image Generator",
        "Tool that can generate image based on text description.",
        name_for_model="Image Generator",
        # Adjacent literals are concatenated; keep explicit separators.
        description_for_model=(
            "Useful for when you need to generate an image. "
            "Input: A detailed text-2-image prompt describing an image. "
            "Output: the UUID of a generated image"
        ),
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="[email protected]",
        # NOTE(review): this value is an e-mail placeholder, not a URL -- confirm.
        legal_info_url="[email protected]",
    )

    # Generation settings (edit here to pick another model or resolution).
    model_name: ModelName = ModelName.DALL_E
    size: str = "512x512"
    # When True, return a signed URL for the image instead of its block id.
    return_urls: bool = False

    if size not in SUPPORTED_IMAGE_SIZES[model_name]:
        raise ValueError(
            f"Size {size!r} is not supported by {model_name.value}; "
            f"choose one of {SUPPORTED_IMAGE_SIZES[model_name]}"
        )

    steamship_api_key = os.environ.get("STEAMSHIP_API_KEY", "")
    if not steamship_api_key:
        raise RuntimeError(
            "STEAMSHIP_API_KEY is not provided. Please sign up for a free account at https://steamship.com/account/api, create a new API key, and add it to environment variables."
        )
    client = Steamship(api_key=steamship_api_key)

    # Register the handler on the tool; without this the returned tool exposes
    # no operation. Assumes Tool provides FastAPI-style route decorators
    # (as in BMTools) -- confirm against ..tool.Tool.
    @tool.get("/generate_image")
    def generate_image(query: str):
        """Generate one image for ``query``.

        Returns:
            The generated block's id, or a signed URL when ``return_urls``.

        Raises:
            RuntimeError: If the plugin produced no output blocks.
        """
        image_generator = client.use_plugin(
            plugin_handle=model_name.value, config={"n": 1, "size": size}
        )
        task = image_generator.generate(text=query, append_output_to_file=True)
        task.wait()
        blocks = task.output.blocks
        # Check emptiness before indexing so an empty result raises the
        # intended RuntimeError rather than IndexError.
        if not blocks:
            raise RuntimeError("Tool unable to generate image!")
        if return_urls:
            return make_image_public(client, blocks[0])
        return blocks[0].id

    return tool