akhaliq (HF Staff) committed
Commit 01bf5a7 · verified · 1 Parent(s): c658133

Update app.py

Files changed (1)
  1. app.py +128 -118
app.py CHANGED
@@ -3,142 +3,156 @@ import numpy as np
 import random
 import os
 import tempfile
+import subprocess
+import json
 from PIL import Image, ImageOps
 import pillow_heif  # For HEIF/AVIF support
 import io
-import fal_client
-import base64
 
 # --- Constants ---
 MAX_SEED = np.iinfo(np.int32).max
 
-def load_client():
-    """Initialize the FAL Client through HF"""
-    # Register HEIF opener with PIL for AVIF/HEIF support
-    pillow_heif.register_heif_opener()
+def setup_node_environment():
+    """Setup Node.js environment and install required packages"""
+    try:
+        # Check if node is available
+        result = subprocess.run(['node', '--version'], capture_output=True, text=True)
+        if result.returncode != 0:
+            raise gr.Error("Node.js is not installed. Please install Node.js to use this feature.")
+
+        # Check if @huggingface/inference is installed, if not install it
+        package_check = subprocess.run(['npm', 'list', '@huggingface/inference'], capture_output=True, text=True)
+        if package_check.returncode != 0:
+            print("Installing @huggingface/inference package...")
+            install_result = subprocess.run(['npm', 'install', '@huggingface/inference'], capture_output=True, text=True)
+            if install_result.returncode != 0:
+                raise gr.Error(f"Failed to install @huggingface/inference: {install_result.stderr}")
+
+        return True
+    except FileNotFoundError:
+        raise gr.Error("Node.js or npm not found. Please install Node.js and npm.")
+
+def create_js_inference_script(image_path, prompt, hf_token):
+    """Create JavaScript inference script"""
+    js_code = f"""
+const {{ InferenceClient }} = require("@huggingface/inference");
+const fs = require("fs");
+
+async function runInference() {{
+    try {{
+        const client = new InferenceClient("{hf_token}");
+        const data = fs.readFileSync("{image_path}");
+
+        const image = await client.imageToImage({{
+            provider: "replicate",
+            model: "black-forest-labs/FLUX.1-Kontext-dev",
+            inputs: data,
+            parameters: {{ prompt: "{prompt}" }},
+        }}, {{
+            billTo: "huggingface",
+        }});
+
+        // Convert blob to buffer
+        const arrayBuffer = await image.arrayBuffer();
+        const buffer = Buffer.from(arrayBuffer);
+
+        // Output as base64 for Python to read
+        const base64 = buffer.toString('base64');
+        console.log(JSON.stringify({{
+            success: true,
+            image_base64: base64,
+            content_type: image.type || 'image/jpeg'
+        }}));
+
+    }} catch (error) {{
+        console.log(JSON.stringify({{
+            success: false,
+            error: error.message
+        }}));
+        process.exit(1);
+    }}
+}}
+
+runInference();
+"""
+    return js_code
+
+def query_api_js(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
+    """Send request using JavaScript HF Inference Client"""
 
     # Get token from environment variable
     hf_token = os.getenv("HF_TOKEN")
     if not hf_token:
-        raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")
+        raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the environment.")
 
-    # Set the HF token for fal_client to use HF routing
-    os.environ["FAL_KEY"] = hf_token
-    return True
-
-def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
-    """Send request using fal_client"""
-
-    load_client()
+    if progress_callback:
+        progress_callback(0.1, "Setting up Node.js environment...")
+
+    # Setup Node.js environment
+    setup_node_environment()
 
     if progress_callback:
-        progress_callback(0.1, "Submitting request...")
-
-    # Convert image bytes to base64
-    image_base64 = base64.b64encode(image_bytes).decode('utf-8')
+        progress_callback(0.2, "Preparing image...")
 
     # Create a temporary file for the image
     with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
         temp_file.write(image_bytes)
-        temp_file_path = temp_file.name
+        temp_image_path = temp_file.name
 
-    def on_queue_update(update):
-        if isinstance(update, fal_client.InProgress):
-            for log in update.logs:
-                print(f"FAL Log: {log['message']}")
-                if progress_callback:
-                    progress_callback(0.5, f"Processing: {log['message'][:50]}...")
+    # Create temporary JavaScript file
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as js_file:
+        js_code = create_js_inference_script(temp_image_path, prompt.replace('"', '\\"'), hf_token)
+        js_file.write(js_code)
+        js_file_path = js_file.name
 
     try:
         if progress_callback:
-            progress_callback(0.3, "Connecting to FAL API...")
-
-        # Use fal_client.subscribe following the pattern you provided
-        result = fal_client.subscribe(
-            "fal-ai/flux-kontext/dev",
-            arguments={
-                "prompt": prompt,
-                "image_url": f"data:image/png;base64,{image_base64}",
-                "seed": seed,
-                "guidance_scale": guidance_scale,
-                "num_inference_steps": steps,
-            },
-            with_logs=True,
-            on_queue_update=on_queue_update,
+            progress_callback(0.3, "Running JavaScript inference...")
+
+        # Run the JavaScript code
+        result = subprocess.run(
+            ['node', js_file_path],
+            capture_output=True,
+            text=True,
+            timeout=300  # 5 minute timeout
         )
 
-        print(f"FAL Result: {result}")
+        if progress_callback:
+            progress_callback(0.8, "Processing result...")
+
+        if result.returncode != 0:
+            raise gr.Error(f"JavaScript inference failed: {result.stderr}")
+
+        # Parse the JSON output
+        try:
+            output = json.loads(result.stdout.strip())
+        except json.JSONDecodeError:
+            raise gr.Error(f"Failed to parse JavaScript output: {result.stdout}")
+
+        if not output.get('success'):
+            raise gr.Error(f"Inference error: {output.get('error', 'Unknown error')}")
 
         if progress_callback:
-            progress_callback(0.9, "Processing result...")
-
-        # Handle the result
-        if isinstance(result, dict):
-            if 'images' in result and len(result['images']) > 0:
-                # Get the first image
-                image_info = result['images'][0]
-                if isinstance(image_info, dict) and 'url' in image_info:
-                    # Download image from URL
-                    import requests
-                    img_response = requests.get(image_info['url'])
-                    if img_response.status_code == 200:
-                        if progress_callback:
-                            progress_callback(1.0, "Complete!")
-                        return img_response.content
-                    else:
-                        raise gr.Error(f"Failed to download result image: {img_response.status_code}")
-                elif isinstance(image_info, str):
-                    # Direct URL
-                    import requests
-                    img_response = requests.get(image_info)
-                    if img_response.status_code == 200:
-                        if progress_callback:
-                            progress_callback(1.0, "Complete!")
-                        return img_response.content
-            elif 'image' in result:
-                # Single image field
-                if isinstance(result['image'], dict) and 'url' in result['image']:
-                    import requests
-                    img_response = requests.get(result['image']['url'])
-                    if img_response.status_code == 200:
-                        if progress_callback:
-                            progress_callback(1.0, "Complete!")
-                        return img_response.content
-                elif isinstance(result['image'], str):
-                    # Could be URL or base64
-                    if result['image'].startswith('http'):
-                        import requests
-                        img_response = requests.get(result['image'])
-                        if img_response.status_code == 200:
-                            if progress_callback:
-                                progress_callback(1.0, "Complete!")
-                            return img_response.content
-                    else:
-                        # Assume base64
-                        try:
-                            if progress_callback:
-                                progress_callback(1.0, "Complete!")
-                            return base64.b64decode(result['image'])
-                        except:
-                            pass
-            elif 'url' in result:
-                # Direct URL in result
-                import requests
-                img_response = requests.get(result['url'])
-                if img_response.status_code == 200:
-                    if progress_callback:
-                        progress_callback(1.0, "Complete!")
-                    return img_response.content
-
-        # If we get here, the result format is unexpected
-        raise gr.Error(f"Unexpected result format from FAL API: {result}")
+            progress_callback(0.9, "Decoding image...")
+
+        # Decode base64 image
+        import base64
+        image_data = base64.b64decode(output['image_base64'])
+
+        if progress_callback:
+            progress_callback(1.0, "Complete!")
+
+        return image_data
+
+    except subprocess.TimeoutExpired:
+        raise gr.Error("Inference timed out. Please try again.")
     except Exception as e:
-        raise gr.Error(f"FAL API error: {str(e)}")
+        raise gr.Error(f"Error running JavaScript inference: {str(e)}")
     finally:
-        # Clean up temporary file
+        # Clean up temporary files
        try:
-            os.unlink(temp_file_path)
+            os.unlink(temp_image_path)
+            os.unlink(js_file_path)
        except:
            pass
 
@@ -159,6 +173,9 @@ def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps,
     if files:
         print(f"Received image: {files[0]}")
         try:
+            # Register HEIF opener with PIL for AVIF/HEIF support
+            pillow_heif.register_heif_opener()
+
             # Try to open and convert the image
             input_image = Image.open(files[0])
             # Convert to RGB if needed (handles RGBA, P, etc.)
@@ -183,8 +200,8 @@ def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps,
         raise gr.Error("This model (FLUX.1 Kontext) requires an input image. Please upload an image to edit.")
 
     try:
-        # Make API request
-        result_bytes = query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=progress)
+        # Make API request using JavaScript
+        result_bytes = query_api_js(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=progress)
 
         # Try to convert response bytes to PIL Image
         try:
@@ -192,14 +209,7 @@ def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps,
         except Exception as img_error:
             print(f"Failed to open image: {img_error}")
             print(f"Image bytes type: {type(result_bytes)}, length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")
-
-            # Try to decode as base64 if direct opening failed
-            try:
-                import base64
-                decoded_bytes = base64.b64decode(result_bytes)
-                image = Image.open(io.BytesIO(decoded_bytes))
-            except:
-                raise gr.Error(f"Could not process API response as image. Response length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")
+            raise gr.Error(f"Could not process API response as image. Response length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")
 
         progress(1.0, desc="Complete!")
         return gr.Image(value=image)
@@ -219,9 +229,9 @@ steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)
 
 demo = gr.ChatInterface(
     fn=chat_fn,
-    title="FLUX.1 Kontext [dev] - FAL Client",
+    title="FLUX.1 Kontext [dev] - HF Inference Client (JS)",
     description="""<p style='text-align: center;'>
-    A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using FAL AI client through Hugging Face.
+    A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using Hugging Face Inference Client via JavaScript.
     <br>
     <b>Upload an image</b> and type your editing instructions (e.g., "Turn the cat into a tiger", "Add a hat").
     <br>
@@ -229,7 +239,7 @@ demo = gr.ChatInterface(
     <br>
     Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
     <br>
-    <b>Note:</b> Uses HF_TOKEN environment variable through HF inference providers.
+    <b>Requirements:</b> Node.js and npm must be installed. Uses HF_TOKEN environment variable.
     </p>""",
     multimodal=True,
     textbox=gr.MultimodalTextbox(
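
For reference: the Node.js bridge added by this commit exists only to reach @huggingface/inference's imageToImage from Python. A minimal sketch of the same call through the Python huggingface_hub client follows; it is not part of this commit, assumes a recent huggingface_hub release with inference-provider routing, and uses placeholder values for the input path and prompt.

import os
from huggingface_hub import InferenceClient

# Route through the Replicate provider, mirroring the JS script's
# provider/model choice. bill_to mirrors the JS billTo option; it is an
# assumption here, as it only exists in newer huggingface_hub releases.
client = InferenceClient(provider="replicate", api_key=os.environ["HF_TOKEN"], bill_to="huggingface")

with open("input.png", "rb") as f:  # placeholder input path
    image_bytes = f.read()

# image_to_image accepts raw bytes and returns a PIL.Image
edited = client.image_to_image(
    image_bytes,
    prompt="Turn the cat into a tiger",  # placeholder prompt
    model="black-forest-labs/FLUX.1-Kontext-dev",
)
edited.save("output.png")

If this behaves the same as the JS client, it would remove the Node.js/npm dependency, the temporary .js file, and the base64 round-trip through stdout.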