Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -23,25 +23,25 @@ def load_client():
|
|
23 |
return hf_token
|
24 |
|
25 |
def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
|
26 |
-
"""Send request to the API using HF
|
27 |
import requests
|
28 |
import json
|
|
|
29 |
|
30 |
hf_token = load_client()
|
31 |
|
32 |
if progress_callback:
|
33 |
progress_callback(0.1, "Submitting request...")
|
34 |
|
35 |
-
#
|
36 |
-
url = "https://
|
37 |
headers = {
|
38 |
"Authorization": f"Bearer {hf_token}",
|
39 |
-
"
|
40 |
-
"
|
41 |
}
|
42 |
|
43 |
# Convert image to base64
|
44 |
-
import base64
|
45 |
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
46 |
|
47 |
payload = {
|
@@ -51,9 +51,6 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
|
|
51 |
"seed": seed,
|
52 |
"guidance_scale": guidance_scale,
|
53 |
"num_inference_steps": steps
|
54 |
-
},
|
55 |
-
"options": {
|
56 |
-
"wait_for_model": True
|
57 |
}
|
58 |
}
|
59 |
|
@@ -63,35 +60,13 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
|
|
63 |
try:
|
64 |
response = requests.post(url, headers=headers, json=payload, timeout=300)
|
65 |
|
66 |
-
if response.status_code != 200:
|
67 |
-
# Try alternative approach with fal-ai provider routing
|
68 |
-
alt_url = "https://router.huggingface.co/fal-ai/black-forest-labs/FLUX.1-Kontext-dev"
|
69 |
-
alt_headers = {
|
70 |
-
"Authorization": f"Bearer {hf_token}",
|
71 |
-
"X-HF-Bill-To": "huggingface",
|
72 |
-
"Content-Type": "application/json"
|
73 |
-
}
|
74 |
-
|
75 |
-
alt_payload = {
|
76 |
-
"inputs": image_base64,
|
77 |
-
"parameters": {
|
78 |
-
"prompt": prompt,
|
79 |
-
"seed": seed,
|
80 |
-
"guidance_scale": guidance_scale,
|
81 |
-
"num_inference_steps": steps
|
82 |
-
}
|
83 |
-
}
|
84 |
-
|
85 |
-
if progress_callback:
|
86 |
-
progress_callback(0.5, "Trying alternative routing...")
|
87 |
-
|
88 |
-
response = requests.post(alt_url, headers=alt_headers, json=alt_payload, timeout=300)
|
89 |
-
|
90 |
if response.status_code != 200:
|
91 |
raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")
|
92 |
|
93 |
# Check if response is image bytes or JSON
|
94 |
content_type = response.headers.get('content-type', '').lower()
|
|
|
|
|
95 |
|
96 |
if 'image/' in content_type:
|
97 |
# Direct image response
|
@@ -99,12 +74,18 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
|
|
99 |
progress_callback(1.0, "Complete!")
|
100 |
return response.content
|
101 |
elif 'application/json' in content_type:
|
102 |
-
# JSON response
|
103 |
try:
|
104 |
json_response = response.json()
|
105 |
print(f"JSON response: {json_response}")
|
106 |
|
107 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
if 'images' in json_response and len(json_response['images']) > 0:
|
109 |
image_info = json_response['images'][0]
|
110 |
if isinstance(image_info, dict) and 'url' in image_info:
|
@@ -112,9 +93,12 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
|
|
112 |
if progress_callback:
|
113 |
progress_callback(0.9, "Downloading result...")
|
114 |
img_response = requests.get(image_info['url'])
|
115 |
-
if
|
116 |
-
progress_callback
|
117 |
-
|
|
|
|
|
|
|
118 |
elif isinstance(image_info, str):
|
119 |
# Base64 encoded image
|
120 |
if progress_callback:
|
@@ -127,7 +111,7 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
|
|
127 |
return base64.b64decode(json_response['image'])
|
128 |
else:
|
129 |
raise gr.Error(f"Unexpected JSON response format: {json_response}")
|
130 |
-
except
|
131 |
raise gr.Error(f"Failed to parse JSON response: {str(e)}")
|
132 |
else:
|
133 |
# Try to treat as image bytes
|
@@ -136,12 +120,22 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
|
|
136 |
progress_callback(1.0, "Complete!")
|
137 |
return response.content
|
138 |
else:
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
except requests.exceptions.Timeout:
|
142 |
raise gr.Error("Request timed out. Please try again.")
|
143 |
except requests.exceptions.RequestException as e:
|
144 |
raise gr.Error(f"Request failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
# --- Core Inference Function for ChatInterface ---
|
147 |
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
|
|
|
23 |
return hf_token
|
24 |
|
25 |
def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
|
26 |
+
"""Send request to the API using HF Router for fal.ai provider"""
|
27 |
import requests
|
28 |
import json
|
29 |
+
import base64
|
30 |
|
31 |
hf_token = load_client()
|
32 |
|
33 |
if progress_callback:
|
34 |
progress_callback(0.1, "Submitting request...")
|
35 |
|
36 |
+
# Use the HF router to access fal.ai provider
|
37 |
+
url = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev"
|
38 |
headers = {
|
39 |
"Authorization": f"Bearer {hf_token}",
|
40 |
+
"X-HF-Bill-To": "huggingface",
|
41 |
+
"Content-Type": "application/json"
|
42 |
}
|
43 |
|
44 |
# Convert image to base64
|
|
|
45 |
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
46 |
|
47 |
payload = {
|
|
|
51 |
"seed": seed,
|
52 |
"guidance_scale": guidance_scale,
|
53 |
"num_inference_steps": steps
|
|
|
|
|
|
|
54 |
}
|
55 |
}
|
56 |
|
|
|
60 |
try:
|
61 |
response = requests.post(url, headers=headers, json=payload, timeout=300)
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
if response.status_code != 200:
|
64 |
raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")
|
65 |
|
66 |
# Check if response is image bytes or JSON
|
67 |
content_type = response.headers.get('content-type', '').lower()
|
68 |
+
print(f"Response content type: {content_type}")
|
69 |
+
print(f"Response length: {len(response.content)}")
|
70 |
|
71 |
if 'image/' in content_type:
|
72 |
# Direct image response
|
|
|
74 |
progress_callback(1.0, "Complete!")
|
75 |
return response.content
|
76 |
elif 'application/json' in content_type:
|
77 |
+
# JSON response - might be queue status or result
|
78 |
try:
|
79 |
json_response = response.json()
|
80 |
print(f"JSON response: {json_response}")
|
81 |
|
82 |
+
# Check if it's a queue response
|
83 |
+
if json_response.get("status") == "IN_QUEUE":
|
84 |
+
if progress_callback:
|
85 |
+
progress_callback(0.4, "Request queued, please wait...")
|
86 |
+
raise gr.Error("Request is being processed. Please try again in a few moments.")
|
87 |
+
|
88 |
+
# Handle immediate completion or result
|
89 |
if 'images' in json_response and len(json_response['images']) > 0:
|
90 |
image_info = json_response['images'][0]
|
91 |
if isinstance(image_info, dict) and 'url' in image_info:
|
|
|
93 |
if progress_callback:
|
94 |
progress_callback(0.9, "Downloading result...")
|
95 |
img_response = requests.get(image_info['url'])
|
96 |
+
if img_response.status_code == 200:
|
97 |
+
if progress_callback:
|
98 |
+
progress_callback(1.0, "Complete!")
|
99 |
+
return img_response.content
|
100 |
+
else:
|
101 |
+
raise gr.Error(f"Failed to download image: {img_response.status_code}")
|
102 |
elif isinstance(image_info, str):
|
103 |
# Base64 encoded image
|
104 |
if progress_callback:
|
|
|
111 |
return base64.b64decode(json_response['image'])
|
112 |
else:
|
113 |
raise gr.Error(f"Unexpected JSON response format: {json_response}")
|
114 |
+
except json.JSONDecodeError as e:
|
115 |
raise gr.Error(f"Failed to parse JSON response: {str(e)}")
|
116 |
else:
|
117 |
# Try to treat as image bytes
|
|
|
120 |
progress_callback(1.0, "Complete!")
|
121 |
return response.content
|
122 |
else:
|
123 |
+
# Small response, probably an error
|
124 |
+
try:
|
125 |
+
error_text = response.content.decode('utf-8')
|
126 |
+
raise gr.Error(f"Unexpected response: {error_text[:500]}")
|
127 |
+
except:
|
128 |
+
raise gr.Error(f"Unexpected response format. Content length: {len(response.content)}")
|
129 |
|
130 |
except requests.exceptions.Timeout:
|
131 |
raise gr.Error("Request timed out. Please try again.")
|
132 |
except requests.exceptions.RequestException as e:
|
133 |
raise gr.Error(f"Request failed: {str(e)}")
|
134 |
+
except gr.Error:
|
135 |
+
# Re-raise Gradio errors as-is
|
136 |
+
raise
|
137 |
+
except Exception as e:
|
138 |
+
raise gr.Error(f"Unexpected error: {str(e)}")
|
139 |
|
140 |
# --- Core Inference Function for ChatInterface ---
|
141 |
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
|