akhaliq (HF Staff) committed
Commit a3c7c9b · verified · 1 Parent(s): 618f8cb

Update app.py

Files changed (1):
  app.py +103 -95
app.py CHANGED
@@ -6,12 +6,14 @@ import tempfile
 from PIL import Image, ImageOps
 import pillow_heif  # For HEIF/AVIF support
 import io
+import fal_client
+import base64
 
 # --- Constants ---
 MAX_SEED = np.iinfo(np.int32).max
 
 def load_client():
-    """Initialize the Inference Client"""
+    """Initialize the FAL Client through HF"""
     # Register HEIF opener with PIL for AVIF/HEIF support
     pillow_heif.register_heif_opener()
 
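The swap to the fal_client SDK starts with the two new imports above. As an aside, the pillow_heif line that both versions keep is what lets PIL decode HEIC/AVIF uploads at all; a minimal sketch of that mechanism (not part of the commit; the file name is hypothetical):

import pillow_heif
from PIL import Image

# Registering the opener adds HEIF decoding to PIL's format registry,
# so Image.open() can read HEIC uploads like any other format.
pillow_heif.register_heif_opener()

img = Image.open("upload.heic")        # hypothetical input file
img.convert("RGB").save("upload.png")  # normalize to PNG before encoding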
@@ -20,121 +22,125 @@ def load_client():
     if not hf_token:
         raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")
 
-    return hf_token
+    # Set the HF token for fal_client to use HF routing
+    os.environ["FAL_KEY"] = hf_token
+    return True
 
 def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
-    """Send request to the API using HF Router for fal.ai provider"""
-    import requests
-    import json
-    import base64
+    """Send request using fal_client"""
 
-    hf_token = load_client()
+    load_client()
 
     if progress_callback:
         progress_callback(0.1, "Submitting request...")
 
-    # Use the HF router to access fal.ai provider
-    url = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev"
-    headers = {
-        "Authorization": f"Bearer {hf_token}",
-        "X-HF-Bill-To": "huggingface",
-        "Content-Type": "application/json"
-    }
-
-    # Convert image to base64
+    # Convert image bytes to base64
     image_base64 = base64.b64encode(image_bytes).decode('utf-8')
 
-    # Fixed payload structure - prompt should be at the top level
-    payload = {
-        "prompt": prompt,
-        "inputs": image_base64,
-        "seed": seed,
-        "guidance_scale": guidance_scale,
-        "num_inference_steps": steps
-    }
+    # Create a temporary file for the image
+    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
+        temp_file.write(image_bytes)
+        temp_file_path = temp_file.name
 
-    if progress_callback:
-        progress_callback(0.3, "Processing request...")
+    def on_queue_update(update):
+        if isinstance(update, fal_client.InProgress):
+            for log in update.logs:
+                print(f"FAL Log: {log['message']}")
+                if progress_callback:
+                    progress_callback(0.5, f"Processing: {log['message'][:50]}...")
 
     try:
-        response = requests.post(url, headers=headers, json=payload, timeout=300)
+        if progress_callback:
+            progress_callback(0.3, "Connecting to FAL API...")
 
-        if response.status_code != 200:
-            raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")
+        # Submit the request via fal_client.subscribe
+        result = fal_client.subscribe(
+            "fal-ai/flux-kontext/dev",
+            arguments={
+                "prompt": prompt,
+                "image_url": f"data:image/png;base64,{image_base64}",
+                "seed": seed,
+                "guidance_scale": guidance_scale,
+                "num_inference_steps": steps,
+            },
+            with_logs=True,
+            on_queue_update=on_queue_update,
+        )
 
-        # Check if response is image bytes or JSON
-        content_type = response.headers.get('content-type', '').lower()
-        print(f"Response content type: {content_type}")
-        print(f"Response length: {len(response.content)}")
+        print(f"FAL Result: {result}")
 
-        if 'image/' in content_type:
-            # Direct image response
-            if progress_callback:
-                progress_callback(1.0, "Complete!")
-            return response.content
-        elif 'application/json' in content_type:
-            # JSON response - might be queue status or result
-            try:
-                json_response = response.json()
-                print(f"JSON response: {json_response}")
-
-                # Check if it's a queue response
-                if json_response.get("status") == "IN_QUEUE":
-                    if progress_callback:
-                        progress_callback(0.4, "Request queued, please wait...")
-                    raise gr.Error("Request is being processed. Please try again in a few moments.")
-
-                # Handle immediate completion or result
-                if 'images' in json_response and len(json_response['images']) > 0:
-                    image_info = json_response['images'][0]
-                    if isinstance(image_info, dict) and 'url' in image_info:
-                        # Download image from URL
+        if progress_callback:
+            progress_callback(0.9, "Processing result...")
+
+        # Handle the result
+        if isinstance(result, dict):
+            if 'images' in result and len(result['images']) > 0:
+                # Get the first image
+                image_info = result['images'][0]
+                if isinstance(image_info, dict) and 'url' in image_info:
+                    # Download image from URL
+                    import requests
+                    img_response = requests.get(image_info['url'])
+                    if img_response.status_code == 200:
                         if progress_callback:
-                            progress_callback(0.9, "Downloading result...")
-                        img_response = requests.get(image_info['url'])
+                            progress_callback(1.0, "Complete!")
+                        return img_response.content
+                    else:
+                        raise gr.Error(f"Failed to download result image: {img_response.status_code}")
+                elif isinstance(image_info, str):
+                    # Direct URL
+                    import requests
+                    img_response = requests.get(image_info)
+                    if img_response.status_code == 200:
+                        if progress_callback:
+                            progress_callback(1.0, "Complete!")
+                        return img_response.content
+            elif 'image' in result:
+                # Single image field
+                if isinstance(result['image'], dict) and 'url' in result['image']:
+                    import requests
+                    img_response = requests.get(result['image']['url'])
+                    if img_response.status_code == 200:
+                        if progress_callback:
+                            progress_callback(1.0, "Complete!")
+                        return img_response.content
+                elif isinstance(result['image'], str):
+                    # Could be URL or base64
+                    if result['image'].startswith('http'):
+                        import requests
+                        img_response = requests.get(result['image'])
                         if img_response.status_code == 200:
                             if progress_callback:
                                 progress_callback(1.0, "Complete!")
                             return img_response.content
-                        else:
-                            raise gr.Error(f"Failed to download image: {img_response.status_code}")
-                    elif isinstance(image_info, str):
-                        # Base64 encoded image
-                        if progress_callback:
-                            progress_callback(1.0, "Complete!")
-                        return base64.b64decode(image_info)
-                elif 'image' in json_response:
-                    # Single image field
+                    else:
+                        # Assume base64
+                        try:
+                            if progress_callback:
+                                progress_callback(1.0, "Complete!")
+                            return base64.b64decode(result['image'])
+                        except:
+                            pass
+            elif 'url' in result:
+                # Direct URL in result
+                import requests
+                img_response = requests.get(result['url'])
+                if img_response.status_code == 200:
                     if progress_callback:
                         progress_callback(1.0, "Complete!")
-                    return base64.b64decode(json_response['image'])
-                else:
-                    raise gr.Error(f"Unexpected JSON response format: {json_response}")
-            except json.JSONDecodeError as e:
-                raise gr.Error(f"Failed to parse JSON response: {str(e)}")
-        else:
-            # Try to treat as image bytes
-            if len(response.content) > 1000:  # Likely an image
-                if progress_callback:
-                    progress_callback(1.0, "Complete!")
-                return response.content
-            else:
-                # Small response, probably an error
-                try:
-                    error_text = response.content.decode('utf-8')
-                    raise gr.Error(f"Unexpected response: {error_text[:500]}")
-                except:
-                    raise gr.Error(f"Unexpected response format. Content length: {len(response.content)}")
-
-    except requests.exceptions.Timeout:
-        raise gr.Error("Request timed out. Please try again.")
-    except requests.exceptions.RequestException as e:
-        raise gr.Error(f"Request failed: {str(e)}")
-    except gr.Error:
-        # Re-raise Gradio errors as-is
-        raise
+                    return img_response.content
+
+        # If we get here, the result format is unexpected
+        raise gr.Error(f"Unexpected result format from FAL API: {result}")
+
     except Exception as e:
-        raise gr.Error(f"Unexpected error: {str(e)}")
+        raise gr.Error(f"FAL API error: {str(e)}")
+    finally:
+        # Clean up temporary file
+        try:
+            os.unlink(temp_file_path)
+        except:
+            pass
 
 # --- Core Inference Function for ChatInterface ---
 def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
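Stripped of progress reporting and fallbacks, the new query_api reduces to one blocking fal_client.subscribe call. A condensed, self-contained sketch of that call path, assuming (as the commit does) that exporting the HF token as FAL_KEY routes authentication through Hugging Face; the file name and prompt are illustrative:

import base64
import os

import fal_client

# Reuse the Hugging Face token as the fal_client credential.
os.environ["FAL_KEY"] = os.environ["HF_TOKEN"]

# Inline the input image as a data URI, as the new code does.
with open("input.png", "rb") as f:
    data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode()

# subscribe() blocks until the job finishes and returns the result payload.
result = fal_client.subscribe(
    "fal-ai/flux-kontext/dev",
    arguments={"prompt": "Add a hat", "image_url": data_uri},
)
print(result["images"][0]["url"])  # fal image endpoints typically return URLs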
@@ -213,15 +219,17 @@ steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)
 
 demo = gr.ChatInterface(
     fn=chat_fn,
-    title="FLUX.1 Kontext [dev] - HF Inference Client",
+    title="FLUX.1 Kontext [dev] - FAL Client",
     description="""<p style='text-align: center;'>
-    A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using Hugging Face Inference Client approach.
+    A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using the FAL AI client through Hugging Face.
     <br>
     <b>Upload an image</b> and type your editing instructions (e.g., "Turn the cat into a tiger", "Add a hat").
     <br>
     This model specializes in understanding context and making precise edits to your images.
     <br>
     Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
+    <br>
+    <b>Note:</b> Uses the HF_TOKEN environment variable through HF inference providers.
     </p>""",
     multimodal=True,
     textbox=gr.MultimodalTextbox(
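One loose end in the new code: query_api writes image_bytes to a temporary file and removes it in the finally block, but never sends that file anywhere; only the base64 data URI is used. If the intent was to upload the file rather than inline it, fal_client's file-upload helper would be the natural route. A sketch under that assumption (fal_client.upload_file is documented for turning local files into hosted URLs; verify against your installed version):

import fal_client

def edit_via_upload(image_path: str, prompt: str) -> dict:
    """Upload a local image to fal storage and run FLUX.1 Kontext on its URL."""
    image_url = fal_client.upload_file(image_path)  # hosted URL instead of a data URI
    return fal_client.subscribe(
        "fal-ai/flux-kontext/dev",
        arguments={"prompt": prompt, "image_url": image_url},
    )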
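The new result handling repeats the download-and-status-check pattern across the 'images', 'image', and 'url' branches. A hypothetical helper (not in the commit) that would collapse those branches into one place:

import base64

import requests

def image_bytes_from(entry):
    """Return raw image bytes from a fal result entry: a dict with 'url', a URL string, or a base64 string."""
    if isinstance(entry, dict) and 'url' in entry:
        entry = entry['url']
    if isinstance(entry, str) and entry.startswith('http'):
        resp = requests.get(entry, timeout=60)
        resp.raise_for_status()  # surface HTTP errors instead of silent fallthrough
        return resp.content
    return base64.b64decode(entry)  # otherwise assume inline base64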