Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -2,171 +2,113 @@ import gradio as gr
|
|
2 |
import numpy as np
|
3 |
import random
|
4 |
import os
|
5 |
-
import
|
6 |
-
import requests
|
7 |
-
import time
|
8 |
-
import io
|
9 |
from PIL import Image, ImageOps
|
10 |
import pillow_heif # For HEIF/AVIF support
|
|
|
11 |
|
12 |
# --- Constants ---
|
13 |
MAX_SEED = np.iinfo(np.int32).max
|
14 |
-
API_URL = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev?_subdomain=queue"
|
15 |
|
16 |
-
def
|
17 |
-
"""
|
|
|
|
|
|
|
|
|
18 |
hf_token = os.getenv("HF_TOKEN")
|
19 |
if not hf_token:
|
20 |
raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")
|
21 |
|
22 |
-
return
|
23 |
-
"Authorization": f"Bearer {hf_token}",
|
24 |
-
"X-HF-Bill-To": "huggingface"
|
25 |
-
}
|
26 |
|
27 |
-
def query_api(
|
28 |
-
"""Send request to the API
|
29 |
-
|
|
|
30 |
|
31 |
-
|
32 |
-
if "image_bytes" in payload:
|
33 |
-
payload["inputs"] = base64.b64encode(payload["image_bytes"]).decode("utf-8")
|
34 |
-
del payload["image_bytes"]
|
35 |
|
36 |
-
# Submit the job
|
37 |
if progress_callback:
|
38 |
progress_callback(0.1, "Submitting request...")
|
39 |
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
|
|
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
raise gr.Error("Missing status_url or response_url in queue response")
|
67 |
-
|
68 |
-
# For fal.ai queue, we need to use direct authentication
|
69 |
-
fal_headers = {
|
70 |
-
"Authorization": f"Key {os.getenv('HF_TOKEN')}",
|
71 |
-
"Content-Type": "application/json"
|
72 |
}
|
73 |
-
|
74 |
-
# Poll for completion using the provided URLs
|
75 |
-
max_attempts = 60 # Wait up to 5 minutes
|
76 |
-
attempt = 0
|
77 |
-
|
78 |
-
while attempt < max_attempts:
|
79 |
-
if progress_callback:
|
80 |
-
progress_callback(0.1 + (attempt / max_attempts) * 0.8, f"Processing... (attempt {attempt + 1}/60)")
|
81 |
-
|
82 |
-
time.sleep(5) # Wait 5 seconds between polls
|
83 |
-
|
84 |
-
# Check status using the provided status_url
|
85 |
-
status_response = requests.get(status_url, headers=fal_headers)
|
86 |
-
|
87 |
-
if status_response.status_code != 200:
|
88 |
-
print(f"Status response: {status_response.status_code} - {status_response.text}")
|
89 |
-
attempt += 1
|
90 |
-
continue
|
91 |
-
|
92 |
-
try:
|
93 |
-
status_data = status_response.json()
|
94 |
-
print(f"Status check {attempt + 1}: {status_data}")
|
95 |
-
|
96 |
-
if status_data.get("status") == "COMPLETED":
|
97 |
-
# Job completed, get the result using response_url
|
98 |
-
result_response = requests.get(response_url, headers=fal_headers)
|
99 |
-
|
100 |
-
if result_response.status_code != 200:
|
101 |
-
print(f"Result response: {result_response.status_code} - {result_response.text}")
|
102 |
-
raise gr.Error(f"Failed to get result: {result_response.status_code}")
|
103 |
-
|
104 |
-
# Check if result is direct image bytes or JSON
|
105 |
-
result_content_type = result_response.headers.get('content-type', '').lower()
|
106 |
-
if 'image/' in result_content_type:
|
107 |
-
# Direct image bytes
|
108 |
-
if progress_callback:
|
109 |
-
progress_callback(1.0, "Complete!")
|
110 |
-
return result_response.content
|
111 |
-
else:
|
112 |
-
# Try to parse as JSON for image URL or base64
|
113 |
-
try:
|
114 |
-
result_data = result_response.json()
|
115 |
-
print(f"Result data: {result_data}")
|
116 |
-
|
117 |
-
# Look for images in various formats
|
118 |
-
if 'images' in result_data and len(result_data['images']) > 0:
|
119 |
-
image_info = result_data['images'][0]
|
120 |
-
if isinstance(image_info, dict) and 'url' in image_info:
|
121 |
-
# Download the image
|
122 |
-
if progress_callback:
|
123 |
-
progress_callback(0.9, "Downloading result...")
|
124 |
-
img_response = requests.get(image_info['url'])
|
125 |
-
if progress_callback:
|
126 |
-
progress_callback(1.0, "Complete!")
|
127 |
-
return img_response.content
|
128 |
-
elif isinstance(image_info, str):
|
129 |
-
# Base64 encoded
|
130 |
-
if progress_callback:
|
131 |
-
progress_callback(1.0, "Complete!")
|
132 |
-
return base64.b64decode(image_info)
|
133 |
-
elif 'image' in result_data:
|
134 |
-
# Single image field
|
135 |
-
if isinstance(result_data['image'], str):
|
136 |
-
if progress_callback:
|
137 |
-
progress_callback(1.0, "Complete!")
|
138 |
-
return base64.b64decode(result_data['image'])
|
139 |
-
else:
|
140 |
-
# Maybe it's direct image bytes
|
141 |
-
if progress_callback:
|
142 |
-
progress_callback(1.0, "Complete!")
|
143 |
-
return result_response.content
|
144 |
-
|
145 |
-
except requests.exceptions.JSONDecodeError:
|
146 |
-
# Result might be direct image bytes
|
147 |
-
if progress_callback:
|
148 |
-
progress_callback(1.0, "Complete!")
|
149 |
-
return result_response.content
|
150 |
-
|
151 |
-
elif status_data.get("status") == "FAILED":
|
152 |
-
error_msg = status_data.get("error", "Unknown error")
|
153 |
-
raise gr.Error(f"Job failed: {error_msg}")
|
154 |
-
|
155 |
-
# Still processing, continue polling
|
156 |
-
attempt += 1
|
157 |
-
|
158 |
-
except requests.exceptions.JSONDecodeError:
|
159 |
-
print("Failed to parse status response, continuing...")
|
160 |
-
attempt += 1
|
161 |
-
continue
|
162 |
|
163 |
-
|
|
|
164 |
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
if 'images' in json_response and len(json_response['images']) > 0:
|
168 |
image_info = json_response['images'][0]
|
169 |
if isinstance(image_info, dict) and 'url' in image_info:
|
|
|
170 |
if progress_callback:
|
171 |
progress_callback(0.9, "Downloading result...")
|
172 |
img_response = requests.get(image_info['url'])
|
@@ -179,177 +121,33 @@ def query_api(payload, progress_callback=None):
|
|
179 |
progress_callback(1.0, "Complete!")
|
180 |
return base64.b64decode(image_info)
|
181 |
elif 'image' in json_response:
|
|
|
182 |
if progress_callback:
|
183 |
progress_callback(1.0, "Complete!")
|
184 |
return base64.b64decode(json_response['image'])
|
185 |
else:
|
186 |
-
raise gr.Error(f"
|
187 |
-
|
188 |
-
|
189 |
-
raise gr.Error(f"Unexpected response status: {json_response.get('status', 'unknown')}")
|
190 |
-
|
191 |
-
except requests.exceptions.JSONDecodeError as e:
|
192 |
-
raise gr.Error(f"Failed to parse JSON response: {str(e)}")
|
193 |
-
|
194 |
-
elif 'image/' in content_type:
|
195 |
-
# Response is direct image bytes
|
196 |
-
if progress_callback:
|
197 |
-
progress_callback(1.0, "Complete!")
|
198 |
-
return response.content
|
199 |
-
|
200 |
-
else:
|
201 |
-
# Unknown content type, but try to handle as image bytes
|
202 |
-
# This might be the case where the router returns the image directly
|
203 |
-
if len(response.content) > 1000: # Likely an image if it's substantial
|
204 |
-
if progress_callback:
|
205 |
-
progress_callback(1.0, "Complete!")
|
206 |
-
return response.content
|
207 |
else:
|
208 |
-
#
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
def upload_image_to_fal(image_bytes):
|
216 |
-
"""Upload image to fal.ai and return the URL"""
|
217 |
-
# For now, we'll use base64 data URI as mentioned in the docs
|
218 |
-
# fal.ai supports base64 data URIs for image_url
|
219 |
-
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
220 |
-
# Detect image format
|
221 |
-
try:
|
222 |
-
img = Image.open(io.BytesIO(image_bytes))
|
223 |
-
format_map = {'JPEG': 'jpeg', 'PNG': 'png', 'WEBP': 'webp'}
|
224 |
-
img_format = format_map.get(img.format, 'jpeg')
|
225 |
-
except:
|
226 |
-
img_format = 'jpeg'
|
227 |
-
|
228 |
-
return f"data:image/{img_format};base64,{image_base64}"
|
229 |
-
"""Send request to the API and return response"""
|
230 |
-
hf_headers = get_headers()
|
231 |
-
|
232 |
-
# Submit the job
|
233 |
-
response = requests.post(API_URL, headers=hf_headers, json=payload)
|
234 |
-
|
235 |
-
if response.status_code != 200:
|
236 |
-
raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")
|
237 |
-
|
238 |
-
# Parse the initial response
|
239 |
-
try:
|
240 |
-
json_response = response.json()
|
241 |
-
print(f"Initial response: {json_response}")
|
242 |
-
except:
|
243 |
-
raise gr.Error("Failed to parse initial API response as JSON")
|
244 |
-
|
245 |
-
# Check if job was queued
|
246 |
-
if json_response.get("status") == "IN_QUEUE":
|
247 |
-
status_url = json_response.get("status_url")
|
248 |
-
if not status_url:
|
249 |
-
raise gr.Error("No status URL provided in queue response")
|
250 |
-
|
251 |
-
# For fal.ai endpoints, we need different headers
|
252 |
-
fal_headers = get_fal_headers()
|
253 |
-
|
254 |
-
# Poll for completion
|
255 |
-
max_attempts = 60 # Wait up to 5 minutes (60 * 5 seconds)
|
256 |
-
attempt = 0
|
257 |
-
|
258 |
-
while attempt < max_attempts:
|
259 |
-
if progress_callback:
|
260 |
-
progress_callback(0.1 + (attempt / max_attempts) * 0.8, f"Processing... (attempt {attempt + 1}/60)")
|
261 |
-
|
262 |
-
time.sleep(5) # Wait 5 seconds between polls
|
263 |
-
|
264 |
-
# Check status with fal.ai headers
|
265 |
-
status_response = requests.get(status_url, headers=fal_headers)
|
266 |
-
|
267 |
-
if status_response.status_code != 200:
|
268 |
-
print(f"Status response: {status_response.status_code} - {status_response.text}")
|
269 |
-
raise gr.Error(f"Status check failed: {status_response.status_code}")
|
270 |
-
|
271 |
-
try:
|
272 |
-
status_data = status_response.json()
|
273 |
-
print(f"Status check {attempt + 1}: {status_data}")
|
274 |
-
|
275 |
-
if status_data.get("status") == "COMPLETED":
|
276 |
-
# Job completed, get the result
|
277 |
-
response_url = json_response.get("response_url")
|
278 |
-
if not response_url:
|
279 |
-
raise gr.Error("No response URL provided")
|
280 |
-
|
281 |
-
# Get result with fal.ai headers
|
282 |
-
result_response = requests.get(response_url, headers=fal_headers)
|
283 |
-
|
284 |
-
if result_response.status_code != 200:
|
285 |
-
print(f"Result response: {result_response.status_code} - {result_response.text}")
|
286 |
-
raise gr.Error(f"Failed to get result: {result_response.status_code}")
|
287 |
-
|
288 |
-
# Check if result is JSON with image data
|
289 |
-
try:
|
290 |
-
result_data = result_response.json()
|
291 |
-
print(f"Result data: {result_data}")
|
292 |
-
|
293 |
-
# Look for image in various possible fields
|
294 |
-
if 'images' in result_data and len(result_data['images']) > 0:
|
295 |
-
# Images array with URLs or base64
|
296 |
-
image_data = result_data['images'][0]
|
297 |
-
if isinstance(image_data, dict) and 'url' in image_data:
|
298 |
-
# Image URL - fetch it
|
299 |
-
img_response = requests.get(image_data['url'])
|
300 |
-
return img_response.content
|
301 |
-
elif isinstance(image_data, str):
|
302 |
-
# Assume base64
|
303 |
-
return base64.b64decode(image_data)
|
304 |
-
elif 'image' in result_data:
|
305 |
-
# Single image field
|
306 |
-
if isinstance(result_data['image'], str):
|
307 |
-
return base64.b64decode(result_data['image'])
|
308 |
-
elif 'url' in result_data:
|
309 |
-
# Direct URL
|
310 |
-
img_response = requests.get(result_data['url'])
|
311 |
-
return img_response.content
|
312 |
-
else:
|
313 |
-
raise gr.Error(f"No image found in result: {result_data}")
|
314 |
-
|
315 |
-
except requests.exceptions.JSONDecodeError:
|
316 |
-
# Result might be direct image bytes
|
317 |
-
return result_response.content
|
318 |
-
|
319 |
-
elif status_data.get("status") == "FAILED":
|
320 |
-
error_msg = status_data.get("error", "Unknown error")
|
321 |
-
raise gr.Error(f"Job failed: {error_msg}")
|
322 |
-
|
323 |
-
# Still processing, continue polling
|
324 |
-
attempt += 1
|
325 |
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
raise gr.Error("
|
330 |
-
|
331 |
-
elif json_response.get("status") == "COMPLETED":
|
332 |
-
# Job completed immediately
|
333 |
-
if 'images' in json_response and len(json_response['images']) > 0:
|
334 |
-
image_data = json_response['images'][0]
|
335 |
-
if isinstance(image_data, str):
|
336 |
-
return base64.b64decode(image_data)
|
337 |
-
elif 'image' in json_response:
|
338 |
-
return base64.b64decode(json_response['image'])
|
339 |
-
else:
|
340 |
-
raise gr.Error(f"No image found in immediate response: {json_response}")
|
341 |
-
|
342 |
-
else:
|
343 |
-
raise gr.Error(f"Unexpected response status: {json_response.get('status', 'unknown')}")
|
344 |
|
345 |
# --- Core Inference Function for ChatInterface ---
|
346 |
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
|
347 |
"""
|
348 |
Performs image generation or editing based on user input from the chat interface.
|
349 |
"""
|
350 |
-
# Register HEIF opener with PIL for AVIF/HEIF support
|
351 |
-
pillow_heif.register_heif_opener()
|
352 |
-
|
353 |
prompt = message["text"]
|
354 |
files = message["files"]
|
355 |
|
@@ -359,16 +157,6 @@ def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps,
|
|
359 |
if randomize_seed:
|
360 |
seed = random.randint(0, MAX_SEED)
|
361 |
|
362 |
-
# Prepare the payload for Hugging Face router
|
363 |
-
payload = {
|
364 |
-
"parameters": {
|
365 |
-
"prompt": prompt,
|
366 |
-
"seed": seed,
|
367 |
-
"guidance_scale": guidance_scale,
|
368 |
-
"num_inference_steps": steps
|
369 |
-
}
|
370 |
-
}
|
371 |
-
|
372 |
if files:
|
373 |
print(f"Received image: {files[0]}")
|
374 |
try:
|
@@ -386,35 +174,33 @@ def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps,
|
|
386 |
img_byte_arr.seek(0)
|
387 |
image_bytes = img_byte_arr.getvalue()
|
388 |
|
389 |
-
# Add image bytes to payload - will be converted to base64 in query_api
|
390 |
-
payload["image_bytes"] = image_bytes
|
391 |
-
|
392 |
except Exception as e:
|
393 |
raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).")
|
394 |
|
395 |
progress(0.1, desc="Processing image...")
|
396 |
else:
|
397 |
-
|
398 |
-
#
|
399 |
-
|
400 |
|
401 |
try:
|
402 |
-
# Make API request
|
403 |
-
|
404 |
|
405 |
# Try to convert response bytes to PIL Image
|
406 |
try:
|
407 |
-
image = Image.open(io.BytesIO(
|
408 |
except Exception as img_error:
|
409 |
print(f"Failed to open image: {img_error}")
|
410 |
-
print(f"Image bytes type: {type(
|
411 |
|
412 |
# Try to decode as base64 if direct opening failed
|
413 |
try:
|
414 |
-
|
|
|
415 |
image = Image.open(io.BytesIO(decoded_bytes))
|
416 |
except:
|
417 |
-
raise gr.Error(f"Could not process API response as image. Response length: {len(
|
418 |
|
419 |
progress(1.0, desc="Complete!")
|
420 |
return gr.Image(value=image)
|
@@ -434,20 +220,20 @@ steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)
|
|
434 |
|
435 |
demo = gr.ChatInterface(
|
436 |
fn=chat_fn,
|
437 |
-
title="FLUX.1 Kontext [dev] -
|
438 |
description="""<p style='text-align: center;'>
|
439 |
-
A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using Hugging Face
|
440 |
<br>
|
441 |
-
|
442 |
<br>
|
443 |
-
|
444 |
<br>
|
445 |
Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
|
446 |
</p>""",
|
447 |
multimodal=True,
|
448 |
textbox=gr.MultimodalTextbox(
|
449 |
file_types=["image"],
|
450 |
-
placeholder="
|
451 |
render=False
|
452 |
),
|
453 |
additional_inputs=[
|
|
|
2 |
import numpy as np
|
3 |
import random
|
4 |
import os
|
5 |
+
import tempfile
|
|
|
|
|
|
|
6 |
from PIL import Image, ImageOps
|
7 |
import pillow_heif # For HEIF/AVIF support
|
8 |
+
import io
|
9 |
|
10 |
# --- Constants ---
|
11 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
12 |
|
13 |
+
def load_client():
|
14 |
+
"""Initialize the Inference Client"""
|
15 |
+
# Register HEIF opener with PIL for AVIF/HEIF support
|
16 |
+
pillow_heif.register_heif_opener()
|
17 |
+
|
18 |
+
# Get token from environment variable
|
19 |
hf_token = os.getenv("HF_TOKEN")
|
20 |
if not hf_token:
|
21 |
raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")
|
22 |
|
23 |
+
return hf_token
|
|
|
|
|
|
|
24 |
|
25 |
+
def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
|
26 |
+
"""Send request to the API using HF Inference Client approach"""
|
27 |
+
import requests
|
28 |
+
import json
|
29 |
|
30 |
+
hf_token = load_client()
|
|
|
|
|
|
|
31 |
|
|
|
32 |
if progress_callback:
|
33 |
progress_callback(0.1, "Submitting request...")
|
34 |
|
35 |
+
# Prepare the request data similar to HF Inference Client
|
36 |
+
url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-Kontext-dev"
|
37 |
+
headers = {
|
38 |
+
"Authorization": f"Bearer {hf_token}",
|
39 |
+
"Content-Type": "application/json",
|
40 |
+
"X-Use-Billing": "huggingface"
|
41 |
+
}
|
42 |
|
43 |
+
# Convert image to base64
|
44 |
+
import base64
|
45 |
+
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
46 |
|
47 |
+
payload = {
|
48 |
+
"inputs": image_base64,
|
49 |
+
"parameters": {
|
50 |
+
"prompt": prompt,
|
51 |
+
"seed": seed,
|
52 |
+
"guidance_scale": guidance_scale,
|
53 |
+
"num_inference_steps": steps
|
54 |
+
},
|
55 |
+
"options": {
|
56 |
+
"wait_for_model": True
|
57 |
+
}
|
58 |
+
}
|
59 |
|
60 |
+
if progress_callback:
|
61 |
+
progress_callback(0.3, "Processing request...")
|
62 |
|
63 |
+
try:
|
64 |
+
response = requests.post(url, headers=headers, json=payload, timeout=300)
|
65 |
+
|
66 |
+
if response.status_code != 200:
|
67 |
+
# Try alternative approach with fal-ai provider routing
|
68 |
+
alt_url = "https://router.huggingface.co/fal-ai/black-forest-labs/FLUX.1-Kontext-dev"
|
69 |
+
alt_headers = {
|
70 |
+
"Authorization": f"Bearer {hf_token}",
|
71 |
+
"X-HF-Bill-To": "huggingface",
|
72 |
+
"Content-Type": "application/json"
|
73 |
+
}
|
74 |
|
75 |
+
alt_payload = {
|
76 |
+
"inputs": image_base64,
|
77 |
+
"parameters": {
|
78 |
+
"prompt": prompt,
|
79 |
+
"seed": seed,
|
80 |
+
"guidance_scale": guidance_scale,
|
81 |
+
"num_inference_steps": steps
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
}
|
83 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
if progress_callback:
|
86 |
+
progress_callback(0.5, "Trying alternative routing...")
|
87 |
|
88 |
+
response = requests.post(alt_url, headers=alt_headers, json=alt_payload, timeout=300)
|
89 |
+
|
90 |
+
if response.status_code != 200:
|
91 |
+
raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")
|
92 |
+
|
93 |
+
# Check if response is image bytes or JSON
|
94 |
+
content_type = response.headers.get('content-type', '').lower()
|
95 |
+
|
96 |
+
if 'image/' in content_type:
|
97 |
+
# Direct image response
|
98 |
+
if progress_callback:
|
99 |
+
progress_callback(1.0, "Complete!")
|
100 |
+
return response.content
|
101 |
+
elif 'application/json' in content_type:
|
102 |
+
# JSON response, might contain image URL or base64
|
103 |
+
try:
|
104 |
+
json_response = response.json()
|
105 |
+
print(f"JSON response: {json_response}")
|
106 |
+
|
107 |
+
# Handle different response formats
|
108 |
if 'images' in json_response and len(json_response['images']) > 0:
|
109 |
image_info = json_response['images'][0]
|
110 |
if isinstance(image_info, dict) and 'url' in image_info:
|
111 |
+
# Download image from URL
|
112 |
if progress_callback:
|
113 |
progress_callback(0.9, "Downloading result...")
|
114 |
img_response = requests.get(image_info['url'])
|
|
|
121 |
progress_callback(1.0, "Complete!")
|
122 |
return base64.b64decode(image_info)
|
123 |
elif 'image' in json_response:
|
124 |
+
# Single image field
|
125 |
if progress_callback:
|
126 |
progress_callback(1.0, "Complete!")
|
127 |
return base64.b64decode(json_response['image'])
|
128 |
else:
|
129 |
+
raise gr.Error(f"Unexpected JSON response format: {json_response}")
|
130 |
+
except Exception as e:
|
131 |
+
raise gr.Error(f"Failed to parse JSON response: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
else:
|
133 |
+
# Try to treat as image bytes
|
134 |
+
if len(response.content) > 1000: # Likely an image
|
135 |
+
if progress_callback:
|
136 |
+
progress_callback(1.0, "Complete!")
|
137 |
+
return response.content
|
138 |
+
else:
|
139 |
+
raise gr.Error(f"Unexpected response format. Content: {response.content[:500]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
+
except requests.exceptions.Timeout:
|
142 |
+
raise gr.Error("Request timed out. Please try again.")
|
143 |
+
except requests.exceptions.RequestException as e:
|
144 |
+
raise gr.Error(f"Request failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
# --- Core Inference Function for ChatInterface ---
|
147 |
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
|
148 |
"""
|
149 |
Performs image generation or editing based on user input from the chat interface.
|
150 |
"""
|
|
|
|
|
|
|
151 |
prompt = message["text"]
|
152 |
files = message["files"]
|
153 |
|
|
|
157 |
if randomize_seed:
|
158 |
seed = random.randint(0, MAX_SEED)
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
if files:
|
161 |
print(f"Received image: {files[0]}")
|
162 |
try:
|
|
|
174 |
img_byte_arr.seek(0)
|
175 |
image_bytes = img_byte_arr.getvalue()
|
176 |
|
|
|
|
|
|
|
177 |
except Exception as e:
|
178 |
raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).")
|
179 |
|
180 |
progress(0.1, desc="Processing image...")
|
181 |
else:
|
182 |
+
# For text-to-image, we need a placeholder image or handle differently
|
183 |
+
# FLUX.1 Kontext is primarily an image-to-image model
|
184 |
+
raise gr.Error("This model (FLUX.1 Kontext) requires an input image. Please upload an image to edit.")
|
185 |
|
186 |
try:
|
187 |
+
# Make API request
|
188 |
+
result_bytes = query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=progress)
|
189 |
|
190 |
# Try to convert response bytes to PIL Image
|
191 |
try:
|
192 |
+
image = Image.open(io.BytesIO(result_bytes))
|
193 |
except Exception as img_error:
|
194 |
print(f"Failed to open image: {img_error}")
|
195 |
+
print(f"Image bytes type: {type(result_bytes)}, length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")
|
196 |
|
197 |
# Try to decode as base64 if direct opening failed
|
198 |
try:
|
199 |
+
import base64
|
200 |
+
decoded_bytes = base64.b64decode(result_bytes)
|
201 |
image = Image.open(io.BytesIO(decoded_bytes))
|
202 |
except:
|
203 |
+
raise gr.Error(f"Could not process API response as image. Response length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")
|
204 |
|
205 |
progress(1.0, desc="Complete!")
|
206 |
return gr.Image(value=image)
|
|
|
220 |
|
221 |
demo = gr.ChatInterface(
|
222 |
fn=chat_fn,
|
223 |
+
title="FLUX.1 Kontext [dev] - HF Inference Client",
|
224 |
description="""<p style='text-align: center;'>
|
225 |
+
A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using Hugging Face Inference Client approach.
|
226 |
<br>
|
227 |
+
<b>Upload an image</b> and type your editing instructions (e.g., "Turn the cat into a tiger", "Add a hat").
|
228 |
<br>
|
229 |
+
This model specializes in understanding context and making precise edits to your images.
|
230 |
<br>
|
231 |
Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
|
232 |
</p>""",
|
233 |
multimodal=True,
|
234 |
textbox=gr.MultimodalTextbox(
|
235 |
file_types=["image"],
|
236 |
+
placeholder="Upload an image and type your editing instructions...",
|
237 |
render=False
|
238 |
),
|
239 |
additional_inputs=[
|