# Source: Hugging Face Space app file (Space status when captured: Running).
import base64
import logging
import os
import tempfile
import time
from io import BytesIO

import gradio as gr
import replicate
import requests
from PIL import Image
# Set up Replicate API key from environment variable.
# os.environ only accepts str values, so assigning the result of os.getenv()
# directly raises TypeError when the variable is unset. Guard the assignment so
# the module still imports and process_images() can report the missing token.
_replicate_token = os.getenv('REPLICATE_API_TOKEN')
if _replicate_token:
    os.environ['REPLICATE_API_TOKEN'] = _replicate_token
def upload_to_imgur(image):
    """
    Upload image to Imgur and return URL
    Alternative: You can use other services like Cloudinary, imgbb, etc.

    The image is serialised as PNG, base64-encoded and posted to the Imgur
    v3 image endpoint using an anonymous Client-ID.

    Args:
        image: object exposing a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The ``data.link`` field of Imgur's JSON response.

    Raises:
        Exception: if Imgur answers with a status other than 200. Connection
            errors and timeouts raised by ``requests`` propagate unchanged.
    """
    # Convert PIL image to base64
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    img_base64 = base64.b64encode(png_buffer.getvalue()).decode()

    # Imgur API (anonymous upload)
    # NOTE(review): the Client-ID is hard-coded in the source and uploads are
    # made to a third-party host - confirm this is acceptable for user images.
    headers = {
        'Authorization': 'Client-ID 0d90e8a3e7d8b4e'  # Public client ID for anonymous uploads
    }
    response = requests.post(
        'https://api.imgur.com/3/image',
        headers=headers,
        data={'image': img_base64},
        # Without a timeout a stalled connection would block the request
        # handler forever; the caller falls back to data URIs on any error.
        timeout=30,
    )
    if response.status_code != 200:
        raise Exception(f"Failed to upload to Imgur: {response.status_code}")
    return response.json()['data']['link']
_LOGGER = logging.getLogger(__name__)

# Seconds to wait for the generated image download before giving up.
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _to_data_uri(image):
    """Encode a PIL image as a base64 PNG ``data:`` URI."""
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"


def _collect_image_urls(images):
    """Return one URL per image: Imgur links, or data URIs if any upload fails."""
    try:
        # Try to upload to Imgur (or your preferred service)
        return [upload_to_imgur(image) for image in images]
    except Exception as upload_error:
        # Fallback: convert *all* images to base64 data URIs. Building a fresh
        # list avoids sending the first image twice when only a later upload
        # failed.
        _LOGGER.warning(
            "Image upload failed (%s); falling back to data URIs", upload_error
        )
        return [_to_data_uri(image) for image in images]


def _url_of(item):
    """Return the URL string carried by one output item, or None."""
    if isinstance(item, str):
        return item
    url = getattr(item, 'url', None)
    # NOTE(review): depending on the replicate client version, `url` may be a
    # method or a plain attribute - accept both. TODO confirm against the
    # pinned client version.
    if callable(url):
        url = url()
    return str(url) if url else None


def _extract_output_url(output):
    """Find the result URL in the shapes replicate.run() output may take."""
    direct_url = _url_of(output)
    if direct_url:
        return direct_url
    if isinstance(output, list):
        return _url_of(output[0]) if output else None
    if hasattr(output, '__iter__'):
        try:
            for item in output:
                candidate = _url_of(item)
                if candidate and candidate.startswith('http'):
                    return candidate
        except Exception as iter_error:
            # Best effort: a failing iterator is treated as "no URL found".
            _LOGGER.warning("Could not iterate model output: %s", iter_error)
    return None


def process_images(prompt, image1, image2=None):
    """
    Process uploaded images with Replicate API

    Args:
        prompt: text instruction passed to the model as ``prompt``.
        image1: required input image (PIL-style object with ``save``).
        image2: optional second input image.

    Returns:
        A ``(image, status_text)`` tuple. ``image`` is the generated PIL image
        on success and ``None`` on any failure; ``status_text`` describes the
        outcome. Errors are reported through the tuple, not raised.
    """
    if image1 is None:
        return None, "Please upload at least one image"

    # Check if API token is set
    if not os.getenv('REPLICATE_API_TOKEN'):
        return None, "β οΈ Please set REPLICATE_API_TOKEN environment variable"

    images = [image1] if image2 is None else [image1, image2]

    try:
        # Upload images to get public URLs (or data URIs as a fallback)
        image_urls = _collect_image_urls(images)

        # Prepare input matching the exact format from your example
        input_data = {
            "prompt": prompt,
            "image_input": image_urls
        }

        # Run the model
        output = replicate.run(
            "google/nano-banana",
            input=input_data
        )

        # Handle various output formats
        output_url = _extract_output_url(output)
        if not output_url:
            return None, f"β Error: No valid output URL found. Response type: {type(output)}"

        # Download the generated image
        if hasattr(output, 'read'):
            # If output has a read method, use it
            img = Image.open(BytesIO(output.read()))
        else:
            # Otherwise, download from URL
            response = requests.get(output_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
            if response.status_code != 200:
                return None, f"β Error: Failed to download image (Status: {response.status_code})"
            img = Image.open(BytesIO(response.content))

        return img, f"β Image generated successfully! Output URL: {output_url[:50]}..."

    except replicate.exceptions.ModelError as e:
        return None, f"β Model Error: {str(e)}\n\nMake sure 'google/nano-banana' exists and is accessible."
    except Exception as e:
        # Top-level boundary for the UI callback: map known failure texts to
        # friendlier messages and surface everything else verbatim.
        error_msg = str(e)
        if "not found" in error_msg.lower():
            return None, "β Model 'google/nano-banana' not found. Please check:\n1. Model name is correct\n2. Model is public or you have access\n3. Try format: 'owner/model-name'"
        elif "authentication" in error_msg.lower():
            return None, "β Authentication failed. Please check your REPLICATE_API_TOKEN."
        else:
            return None, f"β Error: {error_msg}"
# Create Gradio interface with gradient theme
# Custom stylesheet handed to gr.Blocks(css=...): page background, button and
# input styling, plus the header/description classes used by the gr.HTML banner.
css = """
.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    font-family: 'Inter', sans-serif;
}
.gr-button {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
    border: none;
    color: white;
    font-weight: bold;
    transition: transform 0.2s;
}
.gr-button:hover {
    transform: scale(1.05);
    box-shadow: 0 10px 20px rgba(0,0,0,0.2);
}
.gr-input {
    border-radius: 10px;
    border: 2px solid rgba(255,255,255,0.3);
    background: rgba(255,255,255,0.9);
}
.header-text {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    font-size: 2.5em;
    font-weight: bold;
    text-align: center;
    margin-bottom: 20px;
}
.description-text {
    color: white;
    text-align: center;
    font-size: 1.1em;
    margin-bottom: 30px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
"""
# Build the Gradio interface
# Layout: banner, then one row with an input column (prompt, two image slots,
# generate button, tips) and an output column (result image, status text),
# followed by example prompts and setup notes.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <div class="header-text">π¨ AI Image Style Transfer Studio</div>
        <div class="description-text">
            Upload 1-2 images and describe how you want them styled.
            The AI will create a beautiful transformation!
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### π€ Input Section")
            prompt = gr.Textbox(
                label="βοΈ Style Prompt",
                placeholder="Describe how you want to style your images...",
                lines=3,
                value="Make the sheets in the style of the logo. Make the scene natural."
            )
            with gr.Row():
                image1 = gr.Image(
                    label="Image 1 (Required)",
                    type="pil",
                    height=200
                )
                image2 = gr.Image(
                    label="Image 2 (Optional)",
                    type="pil",
                    height=200
                )
            generate_btn = gr.Button(
                "π Generate Styled Image",
                variant="primary",
                size="lg"
            )
            gr.Markdown("""
            #### π‘ Tips:
            - Upload high-quality images for best results
            - Be specific in your style description
            - Experiment with different prompts!
            """)

        with gr.Column(scale=1):
            gr.Markdown("### π― Output Section")
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=400
            )
            status = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2
            )

    # Examples section
    # Each example only fills the prompt; both image slots are left empty.
    with gr.Row():
        gr.Examples(
            examples=[
                ["Transform into watercolor painting style", None, None],
                ["Make it look like a vintage photograph", None, None],
                ["Apply cyberpunk neon style", None, None],
                ["Convert to minimalist line art", None, None],
            ],
            inputs=[prompt, image1, image2],
            label="Example Prompts"
        )

    # Event handlers
    # process_images returns (image, status_text), matching the two outputs.
    generate_btn.click(
        fn=process_images,
        inputs=[prompt, image1, image2],
        outputs=[output_image, status],
        api_name="generate"
    )

    # Additional information
    gr.Markdown("""
    ---
    ### βοΈ Setup Instructions:
    1. **Set Environment Variable:**
       ```bash
       export REPLICATE_API_TOKEN="your_api_token_here"
       ```
       Get your token from: https://replicate.com/account/api-tokens
    2. **Install Required Packages:**
       ```bash
       pip install gradio replicate pillow requests
       ```
    3. **Image Hosting Options:**
       - **Option A**: Uses Imgur for free image hosting (default)
       - **Option B**: Falls back to base64 data URIs if upload fails
       - **Option C**: Use your own image hosting service (Cloudinary, S3, etc.)
    4. **Model Notes:**
       - Using: `google/nano-banana` model
       - If this model doesn't exist, try:
         - `stability-ai/stable-diffusion`
         - `pharmapsychotic/clip-interrogator`
       - Check available models at: https://replicate.com/explore
    ### π Security:
    - API keys are managed through environment variables
    - Never commit API keys to version control
    - Consider implementing user authentication for production
    """)
# Launch the app
if __name__ == "__main__":
    # NOTE(review): share=True together with server_name="0.0.0.0" makes the
    # app reachable by anyone with the link while it spends the owner's
    # REPLICATE_API_TOKEN - confirm this exposure is intended.
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )