Spaces:
Sleeping
Sleeping
import os | |
import tempfile | |
import logging | |
import requests | |
from pathlib import Path | |
from dataclasses import dataclass | |
from typing import Optional, Tuple, List, Dict, Any, Union | |
from io import BytesIO | |
from PIL import Image | |
import gradio as gr | |
from google import genai | |
from google.genai import types | |
# Application-wide logging: every record at INFO severity or above is emitted
# with a timestamp, the logger name and the level in front of the message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("image_transformer")
@dataclass
class TransformResult:
    """Outcome of a single image transformation request.

    ``ImageTransformer.transform_image`` creates one with the defaults below
    and fills in the fields as the API response is processed.
    """

    # Path of the generated image on disk, or None when no image was returned.
    image_path: Optional[str] = None
    # Text the model streamed back (e.g. when it answered instead of drawing).
    text_output: str = ""
    # False when an exception aborted the transformation.
    success: bool = True
    # Text of the exception that caused the failure when ``success`` is False.
    error_message: str = ""
class ImageTransformer:
    """Transform images through the Gemini API and adapt results for the Gradio UI.

    ``transform_image`` talks to the API and returns a ``TransformResult``;
    ``process_request`` is the Gradio callback that validates inputs, writes a
    temporary input file, and converts the result into (gallery, text) outputs.
    """

    def __init__(self, model_name: str = "gemini-2.0-flash-exp"):
        """Create a transformer that sends its requests to *model_name*."""
        self.model_name = model_name
        logger.info(f"ImageTransformer initialized with model: {model_name}")

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Delete the file at *path*, logging (never raising) on failure."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Could not remove temporary file {path}: {e}")

    def write_binary_data(self, filepath: str, data: bytes) -> None:
        """Write *data* to *filepath*, replacing any existing content.

        Raises:
            Exception: whatever opening or writing raised, re-raised after logging.
        """
        try:
            with open(filepath, "wb") as f:
                f.write(data)
            logger.info(f"Successfully wrote data to {filepath}")
        except Exception as e:
            logger.error(f"Failed to write data to {filepath}: {e}")
            raise

    def initialize_client(self, api_key: str) -> genai.Client:
        """Create a Gemini API client.

        Args:
            api_key: Key entered by the user. When empty or whitespace-only,
                the ``GEMINI_API_KEY`` environment variable is used instead.

        Returns:
            A ``genai.Client`` built with the stripped key.

        Raises:
            ValueError: if neither source provides a non-blank key.
        """
        if not api_key or not api_key.strip():
            api_key = os.environ.get("GEMINI_API_KEY", "")
            # A whitespace-only environment value is as unusable as a missing one.
            if not api_key.strip():
                logger.error("No API key provided and GEMINI_API_KEY not found in environment")
                raise ValueError("API key is required. Either provide one or set GEMINI_API_KEY environment variable.")
        logger.info("Initializing Gemini client")
        return genai.Client(api_key=api_key.strip())

    def create_request_content(self, file_data: Dict[str, Any], instruction_text: str) -> List[types.Content]:
        """Build the single user turn: the uploaded file followed by the instruction.

        Args:
            file_data: mapping with ``"uri"`` and ``"mime_type"`` of the uploaded file.
            instruction_text: the user's transformation instruction.
        """
        logger.info(f"Creating request content with instruction: {instruction_text}")
        return [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_uri(
                        file_uri=file_data["uri"],
                        mime_type=file_data["mime_type"],
                    ),
                    types.Part.from_text(text=instruction_text),
                ],
            ),
        ]

    def create_request_config(self) -> types.GenerateContentConfig:
        """Return the generation config: sampling settings plus image and text output modalities."""
        logger.info("Creating request configuration")
        return types.GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
            response_modalities=["image", "text"],
            response_mime_type="text/plain",
        )

    def transform_image(self, input_image_path: str, instruction: str, api_key: str) -> TransformResult:
        """Transform the image at *input_image_path* according to *instruction*.

        The image is uploaded through the Files API and the model's response is
        streamed. The first inline image part is written to a temporary ``.png``
        file, and text parts seen before it are concatenated into ``text_output``.
        The uploaded file and an unused placeholder file are cleaned up
        (best effort) whether or not the call succeeds.

        Returns:
            A ``TransformResult``. On failure ``success`` is False and
            ``error_message`` holds the exception text; no exception escapes.
        """
        result = TransformResult()
        client = None
        uploaded_file = None
        output_path = None
        try:
            client = self.initialize_client(api_key)

            logger.info(f"Uploading file: {input_image_path}")
            uploaded_file = client.files.upload(file=input_image_path)

            contents = self.create_request_content(
                {"uri": uploaded_file.uri, "mime_type": uploaded_file.mime_type},
                instruction,
            )
            config = self.create_request_config()

            # Reserve a path for the generated image; delete=False keeps the file
            # after the handle closes so the caller can reopen it.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                output_path = tmp.name
            logger.info(f"Created temporary output file: {output_path}")

            logger.info("Sending request to Gemini API")
            response_stream = client.models.generate_content_stream(
                model=self.model_name,
                contents=contents,
                config=config,
            )

            for chunk in response_stream:
                candidates = chunk.candidates
                if not candidates or not candidates[0].content or not candidates[0].content.parts:
                    continue
                # One chunk may carry several parts (e.g. text then an image),
                # so inspect every part instead of only the first.
                for part in candidates[0].content.parts:
                    if part.inline_data:
                        logger.info(f"Received image data ({part.inline_data.mime_type})")
                        self.write_binary_data(output_path, part.inline_data.data)
                        result.image_path = output_path
                        break
                    if part.text:
                        # Streamed text arrives as consecutive fragments; join them as-is.
                        result.text_output += part.text
                if result.image_path:
                    break

            if not result.image_path and result.text_output:
                logger.info(f"No image generated. Text output: {result.text_output[:100]}...")
            return result
        except Exception as e:
            logger.exception(f"Error in transform_image: {e}")
            result.success = False
            result.error_message = str(e)
            return result
        finally:
            # Drop the placeholder file if no image was ever written to it.
            if output_path and result.image_path != output_path:
                self._remove_quietly(output_path)
            # Best-effort removal of the server-side copy created by files.upload.
            if client is not None and uploaded_file is not None:
                try:
                    logger.info("Cleanup: deleting uploaded file")
                    client.files.delete(name=uploaded_file.name)
                except Exception as cleanup_error:
                    logger.warning(f"Could not delete uploaded file: {cleanup_error}")

    def _save_input_image(self, input_image) -> str:
        """Write *input_image* to a new temporary PNG file and return its path.

        Accepts a PIL image (upload widget), an http(s) URL string, or a local
        file path string. The temporary file is removed if saving fails.
        """
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            input_path = tmp.name
        try:
            if isinstance(input_image, str):
                if input_image.startswith(("http://", "https://")):
                    logger.info(f"Downloading image from URL: {input_image}")
                    response = requests.get(input_image, timeout=10)
                    response.raise_for_status()
                    img = Image.open(BytesIO(response.content))
                    img.save(input_path)
                    logger.info(f"Saved downloaded image to temporary file: {input_path}")
                else:
                    # Local file path (e.g. a cached example image).
                    with Image.open(input_image) as img:
                        img.save(input_path)
                    logger.info(f"Saved local image to temporary file: {input_path}")
            else:
                input_image.save(input_path)
                logger.info(f"Saved uploaded image to temporary file: {input_path}")
        except Exception:
            self._remove_quietly(input_path)
            raise
        return input_path

    def _load_output_image(self, image_path: str) -> Image.Image:
        """Load the generated image into memory (RGBA flattened to RGB) and delete its file."""
        try:
            with Image.open(image_path) as img:
                # Image.open is lazy; force the pixel data in before the file goes away.
                img.load()
                output_image = img.convert("RGB") if img.mode == "RGBA" else img.copy()
        finally:
            self._remove_quietly(image_path)
        return output_image

    def process_request(self, input_image, instruction: str, api_key: str) -> Tuple[Optional[List[Image.Image]], str]:
        """Gradio callback: transform *input_image* and produce gallery/text outputs.

        Args:
            input_image: PIL image from the upload widget, an http(s) URL string
                (from the examples), a local file path string, or None.
            instruction: free-text description of the desired transformation.
            api_key: Gemini API key; blank falls back to ``GEMINI_API_KEY``.

        Returns:
            ``(images, message)``: ``images`` is a one-element list holding the
            generated image and ``message`` is ``""``; otherwise ``images`` is
            None and ``message`` explains why (validation failure, error text,
            or the model's text reply).
        """
        input_path = None
        try:
            if input_image is None:
                return None, "Please upload an image to transform."
            if not instruction or not instruction.strip():
                return None, "Please provide transformation instructions."

            input_path = self._save_input_image(input_image)
            result = self.transform_image(input_path, instruction, api_key)

            if not result.success:
                return None, f"Error: {result.error_message}"
            if not result.image_path:
                logger.info("No image generated, returning text response")
                return None, result.text_output or "No output generated. Try adjusting your instructions."

            output_image = self._load_output_image(result.image_path)
            logger.info(f"Successfully processed image: {result.image_path}")
            return [output_image], ""
        except Exception as e:
            logger.exception(f"Error in process_request: {e}")
            return None, f"Error: {str(e)}"
        finally:
            # The uploaded copy lives on the API side; the local temp file is no longer needed.
            if input_path:
                self._remove_quietly(input_path)
def build_ui() -> gr.Blocks:
    """Build the ImageWizard Gradio interface.

    Lays out a header, two informational accordions, an input column (image,
    API key, instructions, transform button), an output column (gallery and
    text box) and clickable examples. The button is wired to
    ``ImageTransformer.process_request``.

    Returns:
        The configured ``gr.Blocks`` app; queueing and launching are left to
        the caller.
    """
    logger.info("Building UI")
    transformer = ImageTransformer()

    custom_css = """
    /* Main theme colors */
    :root {
        --primary-color: #3a506b;
        --secondary-color: #5bc0be;
        --accent-color: #ffd166;
        --background-color: #f8f9fa;
        --text-color: #1c2541;
        --border-radius: 8px;
        --box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
    }
    /* Global styles */
    body {
        font-family: 'Inter', system-ui, -apple-system, BlinkMacSystemFont, sans-serif;
        background-color: var(--background-color);
        color: var(--text-color);
    }
    /* Header styling */
    .app-header {
        display: flex;
        align-items: center;
        gap: 20px;
        padding: 16px 24px;
        background: linear-gradient(135deg, var(--primary-color), #1c2541);
        color: white;
        border-radius: var(--border-radius);
        margin-bottom: 24px;
        box-shadow: var(--box-shadow);
    }
    .app-header img {
        width: 48px;
        height: 48px;
        border-radius: 50%;
        background-color: white;
        padding: 6px;
    }
    .app-header h1 {
        margin: 0;
        font-size: 1.8rem;
        font-weight: 700;
    }
    .app-header p {
        margin: 4px 0 0 0;
        opacity: 0.9;
        font-size: 0.9rem;
    }
    .app-header a {
        color: var(--accent-color);
        text-decoration: none;
        transition: opacity 0.2s;
    }
    .app-header a:hover {
        opacity: 0.8;
        text-decoration: underline;
    }
    /* Accordion styling */
    .accordion-container {
        margin-bottom: 20px;
        border: 1px solid rgba(0, 0, 0, 0.1);
        border-radius: var(--border-radius);
        overflow: hidden;
    }
    .accordion-header {
        background-color: var(--primary-color);
        color: white;
        padding: 12px 16px;
        font-weight: 600;
    }
    .accordion-content {
        padding: 16px;
        background-color: white;
    }
    /* Main content area */
    .main-container {
        display: flex;
        gap: 24px;
        margin-bottom: 24px;
    }
    /* Input column */
    .input-column {
        flex: 1;
        background-color: white;
        padding: 20px;
        border-radius: var(--border-radius);
        box-shadow: var(--box-shadow);
    }
    /* Output column */
    .output-column {
        flex: 1;
        background-color: white;
        padding: 20px;
        border-radius: var(--border-radius);
        box-shadow: var(--box-shadow);
    }
    /* Button styling */
    .generate-button {
        background-color: var(--secondary-color) !important;
        color: white !important;
        border: none !important;
        border-radius: var(--border-radius) !important;
        padding: 12px 24px !important;
        font-weight: 600 !important;
        cursor: pointer !important;
        transition: background-color 0.2s !important;
        width: 100% !important;
        margin-top: 16px !important;
    }
    .generate-button:hover {
        background-color: #4ca8a6 !important;
    }
    /* Image upload area */
    .image-upload {
        border: 2px dashed rgba(0, 0, 0, 0.1);
        border-radius: var(--border-radius);
        padding: 20px;
        text-align: center;
        transition: border-color 0.2s;
    }
    .image-upload:hover {
        border-color: var(--secondary-color);
    }
    /* Input fields */
    input[type="text"], input[type="password"], textarea {
        width: 100%;
        padding: 10px 12px;
        border: 1px solid rgba(0, 0, 0, 0.1);
        border-radius: var(--border-radius);
        margin-bottom: 16px;
        font-family: inherit;
    }
    input[type="text"]:focus, input[type="password"]:focus, textarea:focus {
        border-color: var(--secondary-color);
        outline: none;
    }
    /* Examples section */
    .examples-header {
        margin: 32px 0 16px 0;
        font-weight: 600;
        color: var(--primary-color);
    }
    /* Footer */
    .app-footer {
        text-align: center;
        padding: 16px;
        margin-top: 32px;
        color: rgba(0, 0, 0, 0.5);
        font-size: 0.8rem;
    }
    """

    with gr.Blocks(css=custom_css) as app:
        # Header
        gr.HTML(
            """
            <div class="app-header">
                <div>
                    <img src="https://img.icons8.com/fluency/96/000000/paint-3d.png" alt="App logo">
                </div>
                <div>
                    <h1>ImageWizard</h1>
                    <p>Transform images with AI | <a href="https://aistudio.google.com/apikey" target="_blank" rel="noopener noreferrer">Get API Key</a></p>
                </div>
            </div>
            """
        )

        # API key information
        with gr.Accordion("🔑 API Key Required", open=True):
            gr.HTML(
                """
                <div class="accordion-content">
                    <p><strong>You need a Gemini API key to use this application.</strong></p>
                    <ol>
                        <li>Visit <a href="https://aistudio.google.com/apikey" target="_blank" rel="noopener noreferrer">Google AI Studio</a> to get your free API key</li>
                        <li>Enter the key in the API Key field below</li>
                        <li>Your key is never stored and only sent directly to Google's API</li>
                    </ol>
                </div>
                """
            )

        # Usage instructions
        with gr.Accordion("📝 How To Use", open=False):
            gr.HTML(
                """
                <div class="accordion-content">
                    <h3>How to transform your images:</h3>
                    <ol>
                        <li><strong>Upload an Image:</strong> Click the upload area to select an image (PNG or JPG recommended)</li>
                        <li><strong>Enter your API Key:</strong> Paste your Gemini API key in the designated field</li>
                        <li><strong>Write Instructions:</strong> Clearly describe how you want to transform the image</li>
                        <li><strong>Generate:</strong> Click the Transform button and wait for results</li>
                    </ol>
                    <p><strong>Tips for better results:</strong></p>
                    <ul>
                        <li>Be specific with your instructions (e.g., "change the background to a beach scene" rather than "change the background")</li>
                        <li>If you get text instead of an image, try rephrasing your instructions</li>
                        <li>For best results, use images with clear subjects and simple backgrounds</li>
                    </ul>
                    <p><strong>Please Note:</strong> Do not upload or generate inappropriate content</p>
                </div>
                """
            )

        # Main container: inputs on the left, outputs on the right
        with gr.Row(elem_classes="main-container"):
            with gr.Column(elem_classes="input-column"):
                image_input = gr.Image(
                    type="pil",
                    label="Upload Your Image",
                    image_mode="RGBA",
                    elem_classes="image-upload",
                )
                api_key_input = gr.Textbox(
                    lines=1,
                    placeholder="Enter your Gemini API Key here",
                    label="Gemini API Key",
                    type="password",
                )
                instruction_input = gr.Textbox(
                    lines=3,
                    placeholder="Describe how you want to transform the image...",
                    label="Transformation Instructions",
                )
                # elem_classes hooks the button up to the .generate-button CSS rules above.
                transform_btn = gr.Button(
                    "Transform Image",
                    variant="primary",
                    elem_classes="generate-button",
                )

            with gr.Column(elem_classes="output-column"):
                output_gallery = gr.Gallery(
                    label="Transformed Image",
                    elem_classes="gallery-container",
                )
                output_text = gr.Textbox(
                    label="Text Output",
                    placeholder="If no image is generated, text output will appear here.",
                    elem_classes="text-output",
                )

        transform_btn.click(
            fn=transformer.process_request,
            inputs=[image_input, instruction_input, api_key_input],
            outputs=[output_gallery, output_text],
        )

        # Examples section
        gr.Markdown("## Try These Examples", elem_classes="examples-header")
        # Each row is [image URL, instruction], matching the two example inputs.
        # The API key is deliberately not part of the examples so clicking one
        # never overwrites a key the user has already entered.
        examples = [
            ["https://images.pexels.com/photos/268533/pexels-photo-268533.jpeg", "Change this landscape to night time with stars"],
            ["https://images.pexels.com/photos/1933873/pexels-photo-1933873.jpeg", "Add text that says 'DREAM BIG' in elegant font"],
            ["https://images.pexels.com/photos/1629781/pexels-photo-1629781.jpeg", "Remove the person from this photo"],
            ["https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg", "Make this dog look like it's wearing a superhero cape"],
            ["https://images.unsplash.com/photo-1555396273-367ea4eb4db5", "Add a neon glow effect around the coffee cup"],
            ["https://images.unsplash.com/photo-1501504905252-473c47e087f8", "Make this whiteboard text more legible and colorful"],
        ]
        gr.Examples(
            examples=examples,
            inputs=[image_input, instruction_input],
        )

        # Footer (styled by the .app-footer rule in custom_css)
        gr.HTML(
            """
            <div class="app-footer">
                <p>ImageWizard © 2025 | Powered by Google Gemini and Gradio</p>
            </div>
            """
        )

    return app
# Script entry point.
def main():
    """Build the ImageWizard UI, enable a request queue (max 50 pending) and serve it."""
    logger.info("Starting Image Transformer application")
    demo = build_ui()
    # queue() configures the app in place; launch() then starts serving it.
    demo.queue(max_size=50)
    demo.launch()


if __name__ == "__main__":
    main()