akhaliq (HF Staff) committed
Commit 09f3aa3 · verified · 1 Parent(s): d6ceac3

Update app.py

Files changed (1)
  1. app.py +34 -40
app.py CHANGED
@@ -23,25 +23,25 @@ def load_client():
     return hf_token
 
 def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
-    """Send request to the API using HF Inference Client approach"""
+    """Send request to the API using HF Router for fal.ai provider"""
     import requests
     import json
+    import base64
 
     hf_token = load_client()
 
     if progress_callback:
         progress_callback(0.1, "Submitting request...")
 
-    # Prepare the request data similar to HF Inference Client
-    url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-Kontext-dev"
+    # Use the HF router to access fal.ai provider
+    url = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev"
     headers = {
         "Authorization": f"Bearer {hf_token}",
-        "Content-Type": "application/json",
-        "X-Use-Billing": "huggingface"
+        "X-HF-Bill-To": "huggingface",
+        "Content-Type": "application/json"
     }
 
     # Convert image to base64
-    import base64
     image_base64 = base64.b64encode(image_bytes).decode('utf-8')
 
     payload = {
@@ -51,9 +51,6 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
             "seed": seed,
             "guidance_scale": guidance_scale,
             "num_inference_steps": steps
-        },
-        "options": {
-            "wait_for_model": True
         }
     }
 
@@ -63,35 +60,13 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
     try:
         response = requests.post(url, headers=headers, json=payload, timeout=300)
 
-        if response.status_code != 200:
-            # Try alternative approach with fal-ai provider routing
-            alt_url = "https://router.huggingface.co/fal-ai/black-forest-labs/FLUX.1-Kontext-dev"
-            alt_headers = {
-                "Authorization": f"Bearer {hf_token}",
-                "X-HF-Bill-To": "huggingface",
-                "Content-Type": "application/json"
-            }
-
-            alt_payload = {
-                "inputs": image_base64,
-                "parameters": {
-                    "prompt": prompt,
-                    "seed": seed,
-                    "guidance_scale": guidance_scale,
-                    "num_inference_steps": steps
-                }
-            }
-
-            if progress_callback:
-                progress_callback(0.5, "Trying alternative routing...")
-
-            response = requests.post(alt_url, headers=alt_headers, json=alt_payload, timeout=300)
-
         if response.status_code != 200:
             raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")
 
         # Check if response is image bytes or JSON
        content_type = response.headers.get('content-type', '').lower()
+        print(f"Response content type: {content_type}")
+        print(f"Response length: {len(response.content)}")
 
         if 'image/' in content_type:
             # Direct image response
@@ -99,12 +74,18 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
                 progress_callback(1.0, "Complete!")
             return response.content
         elif 'application/json' in content_type:
-            # JSON response, might contain image URL or base64
+            # JSON response - might be queue status or result
             try:
                 json_response = response.json()
                 print(f"JSON response: {json_response}")
 
-                # Handle different response formats
+                # Check if it's a queue response
+                if json_response.get("status") == "IN_QUEUE":
+                    if progress_callback:
+                        progress_callback(0.4, "Request queued, please wait...")
+                    raise gr.Error("Request is being processed. Please try again in a few moments.")
+
+                # Handle immediate completion or result
                 if 'images' in json_response and len(json_response['images']) > 0:
                     image_info = json_response['images'][0]
                     if isinstance(image_info, dict) and 'url' in image_info:
@@ -112,9 +93,12 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
                         if progress_callback:
                             progress_callback(0.9, "Downloading result...")
                         img_response = requests.get(image_info['url'])
-                        if progress_callback:
-                            progress_callback(1.0, "Complete!")
-                        return img_response.content
+                        if img_response.status_code == 200:
+                            if progress_callback:
+                                progress_callback(1.0, "Complete!")
+                            return img_response.content
+                        else:
+                            raise gr.Error(f"Failed to download image: {img_response.status_code}")
                     elif isinstance(image_info, str):
                         # Base64 encoded image
                         if progress_callback:
@@ -127,7 +111,7 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
                     return base64.b64decode(json_response['image'])
                 else:
                     raise gr.Error(f"Unexpected JSON response format: {json_response}")
-            except Exception as e:
+            except json.JSONDecodeError as e:
                 raise gr.Error(f"Failed to parse JSON response: {str(e)}")
         else:
             # Try to treat as image bytes
@@ -136,12 +120,22 @@ def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callbac
                     progress_callback(1.0, "Complete!")
                 return response.content
             else:
-                raise gr.Error(f"Unexpected response format. Content: {response.content[:500]}")
+                # Small response, probably an error
+                try:
+                    error_text = response.content.decode('utf-8')
+                    raise gr.Error(f"Unexpected response: {error_text[:500]}")
+                except:
+                    raise gr.Error(f"Unexpected response format. Content length: {len(response.content)}")
 
     except requests.exceptions.Timeout:
         raise gr.Error("Request timed out. Please try again.")
     except requests.exceptions.RequestException as e:
         raise gr.Error(f"Request failed: {str(e)}")
+    except gr.Error:
+        # Re-raise Gradio errors as-is
+        raise
+    except Exception as e:
+        raise gr.Error(f"Unexpected error: {str(e)}")
 
 # --- Core Inference Function for ChatInterface ---
 def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
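
For reference, a minimal sketch of driving the updated query_api directly is shown below. It relies only on what the diff shows (the function signature and that load_client() resolves an HF token, e.g. from an HF_TOKEN secret); file names and parameter values are illustrative, and importing app may start the Gradio UI if app.py launches on import.

# Illustrative smoke test for the updated query_api (placeholder values).
from app import query_api  # note: may launch the Gradio UI if app.py launches on import

with open("input.png", "rb") as f:
    image_bytes = f.read()

result = query_api(
    image_bytes,
    prompt="Turn the sky into a sunset",
    seed=42,
    guidance_scale=2.5,
    steps=28,
    progress_callback=lambda frac, msg: print(f"{frac:.0%} {msg}"),
)

# query_api returns raw image bytes on success
with open("output.png", "wb") as f:
    f.write(result)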
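
As a point of comparison only (not what this commit does), the same fal.ai provider route can usually be reached through huggingface_hub.InferenceClient rather than hand-building the router request; exact keyword support (for example a seed argument) depends on the installed huggingface_hub version, so treat this as a hedged sketch.

# Hedged alternative sketch: let huggingface_hub route to the fal.ai provider.
# Not part of this commit; model id taken from the URL removed in this diff.
import os
from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai", api_key=os.environ["HF_TOKEN"])

with open("input.png", "rb") as f:
    image_bytes = f.read()

# image_to_image returns a PIL.Image; parameter names mirror the diff's payload.
edited = client.image_to_image(
    image_bytes,
    prompt="Turn the sky into a sunset",
    model="black-forest-labs/FLUX.1-Kontext-dev",
    guidance_scale=2.5,
    num_inference_steps=28,
)
edited.save("output.png")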