import gradio as gr import replicate import os from PIL import Image import requests from io import BytesIO import time import tempfile import base64 # Set up Replicate API key from environment variable os.environ['REPLICATE_API_TOKEN'] = os.getenv('REPLICATE_API_TOKEN') def upload_to_imgur(image): """ Upload image to Imgur and return URL Alternative: You can use other services like Cloudinary, imgbb, etc. """ import base64 import json # Convert PIL image to base64 buffered = BytesIO() image.save(buffered, format="PNG") img_base64 = base64.b64encode(buffered.getvalue()).decode() # Imgur API (anonymous upload) headers = { 'Authorization': 'Client-ID 0d90e8a3e7d8b4e' # Public client ID for anonymous uploads } response = requests.post( 'https://api.imgur.com/3/image', headers=headers, data={'image': img_base64} ) if response.status_code == 200: data = response.json() return data['data']['link'] else: raise Exception(f"Failed to upload to Imgur: {response.status_code}") def process_images(prompt, image1, image2=None): """ Process uploaded images with Replicate API """ if not image1: return None, "Please upload at least one image" # Check if API token is set if not os.getenv('REPLICATE_API_TOKEN'): return None, "⚠️ Please set REPLICATE_API_TOKEN environment variable" try: status_message = "📤 Uploading images..." # Upload images to get public URLs image_urls = [] try: # Try to upload to Imgur (or your preferred service) url1 = upload_to_imgur(image1) image_urls.append(url1) if image2: url2 = upload_to_imgur(image2) image_urls.append(url2) except Exception as upload_error: # Fallback: Convert to base64 data URIs buffered1 = BytesIO() image1.save(buffered1, format="PNG") img_base64_1 = base64.b64encode(buffered1.getvalue()).decode() image_urls.append(f"data:image/png;base64,{img_base64_1}") if image2: buffered2 = BytesIO() image2.save(buffered2, format="PNG") img_base64_2 = base64.b64encode(buffered2.getvalue()).decode() image_urls.append(f"data:image/png;base64,{img_base64_2}") status_message = "🎨 Processing with nano-banana model..." # Prepare input matching the exact format from your example input_data = { "prompt": prompt, "image_input": image_urls } # Run the model output = replicate.run( "google/nano-banana", input=input_data ) # Handle various output formats output_url = None # Check different possible output formats if hasattr(output, 'url'): output_url = output.url() elif isinstance(output, str): output_url = output elif isinstance(output, list) and len(output) > 0: output_url = output[0] elif hasattr(output, '__iter__'): try: for item in output: if isinstance(item, str) and item.startswith('http'): output_url = item break except: pass if not output_url: return None, f"❌ Error: No valid output URL found. Response type: {type(output)}" # Download the generated image if hasattr(output, 'read'): # If output has a read method, use it img_data = output.read() img = Image.open(BytesIO(img_data)) else: # Otherwise, download from URL response = requests.get(output_url) if response.status_code == 200: img = Image.open(BytesIO(response.content)) else: return None, f"❌ Error: Failed to download image (Status: {response.status_code})" return img, f"✅ Image generated successfully! Output URL: {output_url[:50]}..." except replicate.exceptions.ModelError as e: return None, f"❌ Model Error: {str(e)}\n\nMake sure 'google/nano-banana' exists and is accessible." except Exception as e: error_msg = str(e) if "not found" in error_msg.lower(): return None, "❌ Model 'google/nano-banana' not found. Please check:\n1. Model name is correct\n2. Model is public or you have access\n3. Try format: 'owner/model-name'" elif "authentication" in error_msg.lower(): return None, "❌ Authentication failed. Please check your REPLICATE_API_TOKEN." else: return None, f"❌ Error: {error_msg}" # Create Gradio interface with gradient theme css = """ .gradio-container { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); font-family: 'Inter', sans-serif; } .gr-button { background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); border: none; color: white; font-weight: bold; transition: transform 0.2s; } .gr-button:hover { transform: scale(1.05); box-shadow: 0 10px 20px rgba(0,0,0,0.2); } .gr-input { border-radius: 10px; border: 2px solid rgba(255,255,255,0.3); background: rgba(255,255,255,0.9); } .header-text { background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; font-size: 2.5em; font-weight: bold; text-align: center; margin-bottom: 20px; } .description-text { color: white; text-align: center; font-size: 1.1em; margin-bottom: 30px; text-shadow: 2px 2px 4px rgba(0,0,0,0.2); } """ # Build the Gradio interface with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: gr.HTML("""