openfree commited on
Commit
65f0832
Β·
verified Β·
1 Parent(s): 047bc32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -68
app.py CHANGED
@@ -24,87 +24,84 @@ def process_images(prompt, image1, image2=None):
24
 
25
  try:
26
  import tempfile
27
- import base64
28
 
29
- # Save images temporarily and create data URIs
30
- image_inputs = []
31
 
32
- # Process first image
 
 
 
 
33
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp1:
34
- image1.save(tmp1.name)
 
35
  with open(tmp1.name, 'rb') as f:
36
- img_data = base64.b64encode(f.read()).decode()
37
- # For some models, you might need data URI format
38
- image_inputs.append(f"data:image/png;base64,{img_data}")
39
- os.unlink(tmp1.name) # Clean up temp file
40
 
41
- # Process second image if provided
42
  if image2:
43
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp2:
44
- image2.save(tmp2.name)
45
  with open(tmp2.name, 'rb') as f:
46
- img_data = base64.b64encode(f.read()).decode()
47
- image_inputs.append(f"data:image/png;base64,{img_data}")
48
  os.unlink(tmp2.name)
49
 
50
- status_message = "🎨 Processing your images..."
51
 
52
- # Example using a real Replicate model (stable-diffusion)
53
- # You should replace this with your actual model
54
- try:
55
- # For image-to-image models, use something like:
56
- output = replicate.run(
57
- "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
58
- input={
59
- "prompt": prompt,
60
- "image": image_inputs[0] if image_inputs else None,
61
- "num_outputs": 1,
62
- "guidance_scale": 7.5,
63
- "num_inference_steps": 50
64
- }
65
- )
66
- except:
67
- # Fallback to a simpler text-to-image if image input fails
68
- output = replicate.run(
69
- "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
70
- input={
71
- "prompt": prompt,
72
- "num_outputs": 1
73
- }
74
- )
75
 
76
- # Handle different output formats
 
 
 
 
 
 
77
  output_url = None
78
 
79
- if isinstance(output, list) and len(output) > 0:
80
- # If output is a list, take the first item
81
- output_url = output[0]
82
  elif isinstance(output, str):
83
- # If output is already a string URL
84
  output_url = output
85
- elif hasattr(output, 'url'):
86
- # If output has a url method
87
- output_url = output.url()
88
- elif hasattr(output, '__iter__'):
89
- # If output is iterable, try to get first item
90
- try:
91
- output_url = next(iter(output))
92
- except:
93
- pass
94
 
95
  if not output_url:
96
- return None, "❌ Error: No image content found in response"
97
 
98
- # Download and return the generated image
99
- response = requests.get(output_url)
100
- if response.status_code == 200:
101
- img = Image.open(BytesIO(response.content))
102
- return img, "βœ… Image generated successfully!"
103
  else:
104
- return None, f"❌ Error: Failed to download image (Status: {response.status_code})"
 
 
 
 
 
 
 
105
 
106
  except Exception as e:
107
- return None, f"❌ Error: {str(e)}"
 
 
 
 
 
 
 
108
 
109
  # Create Gradio interface with gradient theme
110
  css = """
@@ -238,22 +235,22 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
238
  ```bash
239
  export REPLICATE_API_TOKEN="your_api_token_here"
240
  ```
 
241
 
242
  2. **Install Required Packages:**
243
  ```bash
244
  pip install gradio replicate pillow requests
245
  ```
246
 
247
- 3. **Available Models to Try:**
248
- - `stability-ai/stable-diffusion` - Text to image generation
249
- - `jagilley/controlnet-canny` - Image style transfer with edge detection
250
- - `rossjillian/controlnet` - Image controlled generation
251
- - Replace the model in the code with your preferred model
252
 
253
- 4. **Note:** For production use, you'll need to:
254
- - Implement proper image upload to cloud storage (S3, Cloudinary, etc.)
255
- - Add proper error handling and rate limiting
256
- - Check the specific input format required by your chosen model
257
 
258
  ### πŸ”’ Security:
259
  - API keys are managed through environment variables
 
24
 
25
  try:
26
  import tempfile
27
+ from replicate.client import Client
28
 
29
+ # Initialize Replicate client
30
+ client = Client(api_token=os.getenv('REPLICATE_API_TOKEN'))
31
 
32
+ # Upload images to get URLs
33
+ # Replicate needs actual URLs, so we need to upload the images first
34
+ image_urls = []
35
+
36
+ # Save and create file handles for upload
37
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp1:
38
+ image1.save(tmp1.name, 'PNG')
39
+ # Upload to Replicate (this creates a temporary URL)
40
  with open(tmp1.name, 'rb') as f:
41
+ file_url = client.upload(f)
42
+ image_urls.append(file_url)
43
+ os.unlink(tmp1.name)
 
44
 
 
45
  if image2:
46
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp2:
47
+ image2.save(tmp2.name, 'PNG')
48
  with open(tmp2.name, 'rb') as f:
49
+ file_url = client.upload(f)
50
+ image_urls.append(file_url)
51
  os.unlink(tmp2.name)
52
 
53
+ status_message = "🎨 Processing your images with nano-banana model..."
54
 
55
+ # Prepare input matching the exact format
56
+ input_data = {
57
+ "prompt": prompt,
58
+ "image_input": image_urls # List of URLs
59
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # Run the model
62
+ output = replicate.run(
63
+ "google/nano-banana",
64
+ input=input_data
65
+ )
66
+
67
+ # Handle the output based on the reference code
68
  output_url = None
69
 
70
+ # The output object should have a url() method based on your example
71
+ if hasattr(output, 'url'):
72
+ output_url = output.url()
73
  elif isinstance(output, str):
 
74
  output_url = output
75
+ elif isinstance(output, list) and len(output) > 0:
76
+ output_url = output[0]
 
 
 
 
 
 
 
77
 
78
  if not output_url:
79
+ return None, "❌ Error: No image URL found in response"
80
 
81
+ # Download the generated image
82
+ # Based on your example, we can also use output.read() if available
83
+ if hasattr(output, 'read'):
84
+ img_data = output.read()
85
+ img = Image.open(BytesIO(img_data))
86
  else:
87
+ # Fallback to downloading from URL
88
+ response = requests.get(output_url)
89
+ if response.status_code == 200:
90
+ img = Image.open(BytesIO(response.content))
91
+ else:
92
+ return None, f"❌ Error: Failed to download image (Status: {response.status_code})"
93
+
94
+ return img, "βœ… Image generated successfully with nano-banana!"
95
 
96
  except Exception as e:
97
+ # If the model doesn't exist or other errors, provide helpful message
98
+ error_msg = str(e)
99
+ if "not found" in error_msg.lower():
100
+ return None, "❌ Model 'google/nano-banana' not found. Please check if the model exists on Replicate."
101
+ elif "authentication" in error_msg.lower():
102
+ return None, "❌ Authentication failed. Please check your REPLICATE_API_TOKEN."
103
+ else:
104
+ return None, f"❌ Error: {error_msg}"
105
 
106
  # Create Gradio interface with gradient theme
107
  css = """
 
235
  ```bash
236
  export REPLICATE_API_TOKEN="your_api_token_here"
237
  ```
238
+ Get your token from: https://replicate.com/account/api-tokens
239
 
240
  2. **Install Required Packages:**
241
  ```bash
242
  pip install gradio replicate pillow requests
243
  ```
244
 
245
+ 3. **Model Information:**
246
+ - Using: `google/nano-banana` model
247
+ - Input: `prompt` (text) and `image_input` (list of image URLs)
248
+ - Output: Generated image based on the style transfer
 
249
 
250
+ 4. **Note:**
251
+ - The model requires actual URLs for images
252
+ - Images are temporarily uploaded to Replicate's servers
253
+ - Check if 'google/nano-banana' is available on your Replicate account
254
 
255
  ### πŸ”’ Security:
256
  - API keys are managed through environment variables