akhaliq (HF Staff) committed on
Commit 9ab45e8 · verified · 1 Parent(s): 01bf5a7

Update app.py

Files changed (1)
  1. app.py +66 -189
app.py CHANGED
@@ -1,166 +1,51 @@
 import gradio as gr
 import numpy as np
+import spaces
+import torch
 import random
 import os
 import tempfile
-import subprocess
-import json
 from PIL import Image, ImageOps
 import pillow_heif # For HEIF/AVIF support
-import io
+
+# Import the pipeline from diffusers
+from diffusers import FluxKontextPipeline
 
 # --- Constants ---
 MAX_SEED = np.iinfo(np.int32).max
 
-def setup_node_environment():
-    """Setup Node.js environment and install required packages"""
-    try:
-        # Check if node is available
-        result = subprocess.run(['node', '--version'], capture_output=True, text=True)
-        if result.returncode != 0:
-            raise gr.Error("Node.js is not installed. Please install Node.js to use this feature.")
-
-        # Check if @huggingface/inference is installed, if not install it
-        package_check = subprocess.run(['npm', 'list', '@huggingface/inference'], capture_output=True, text=True)
-        if package_check.returncode != 0:
-            print("Installing @huggingface/inference package...")
-            install_result = subprocess.run(['npm', 'install', '@huggingface/inference'], capture_output=True, text=True)
-            if install_result.returncode != 0:
-                raise gr.Error(f"Failed to install @huggingface/inference: {install_result.stderr}")
-
-        return True
-    except FileNotFoundError:
-        raise gr.Error("Node.js or npm not found. Please install Node.js and npm.")
+# --- Global pipeline variable ---
+pipe = None
 
-def create_js_inference_script(image_path, prompt, hf_token):
-    """Create JavaScript inference script"""
-    js_code = f"""
-const {{ InferenceClient }} = require("@huggingface/inference");
-const fs = require("fs");
-
-async function runInference() {{
-    try {{
-        const client = new InferenceClient("{hf_token}");
-        const data = fs.readFileSync("{image_path}");
-
-        const image = await client.imageToImage({{
-            provider: "replicate",
-            model: "black-forest-labs/FLUX.1-Kontext-dev",
-            inputs: data,
-            parameters: {{ prompt: "{prompt}" }},
-        }}, {{
-            billTo: "huggingface",
-        }});
-
-        // Convert blob to buffer
-        const arrayBuffer = await image.arrayBuffer();
-        const buffer = Buffer.from(arrayBuffer);
-
-        // Output as base64 for Python to read
-        const base64 = buffer.toString('base64');
-        console.log(JSON.stringify({{
-            success: true,
-            image_base64: base64,
-            content_type: image.type || 'image/jpeg'
-        }}));
-
-    }} catch (error) {{
-        console.log(JSON.stringify({{
-            success: false,
-            error: error.message
-        }}));
-        process.exit(1);
-    }}
-}}
-
-runInference();
-"""
-    return js_code
-
-def query_api_js(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
-    """Send request using JavaScript HF Inference Client"""
-
-    # Get token from environment variable
-    hf_token = os.getenv("HF_TOKEN")
-    if not hf_token:
-        raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the environment.")
-
-    if progress_callback:
-        progress_callback(0.1, "Setting up Node.js environment...")
-
-    # Setup Node.js environment
-    setup_node_environment()
-
-    if progress_callback:
-        progress_callback(0.2, "Preparing image...")
-
-    # Create a temporary file for the image
-    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
-        temp_file.write(image_bytes)
-        temp_image_path = temp_file.name
-
-    # Create temporary JavaScript file
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as js_file:
-        js_code = create_js_inference_script(temp_image_path, prompt.replace('"', '\\"'), hf_token)
-        js_file.write(js_code)
-        js_file_path = js_file.name
-
-    try:
-        if progress_callback:
-            progress_callback(0.3, "Running JavaScript inference...")
-
-        # Run the JavaScript code
-        result = subprocess.run(
-            ['node', js_file_path],
-            capture_output=True,
-            text=True,
-            timeout=300  # 5 minute timeout
-        )
-
-        if progress_callback:
-            progress_callback(0.8, "Processing result...")
-
-        if result.returncode != 0:
-            raise gr.Error(f"JavaScript inference failed: {result.stderr}")
-
-        # Parse the JSON output
-        try:
-            output = json.loads(result.stdout.strip())
-        except json.JSONDecodeError:
-            raise gr.Error(f"Failed to parse JavaScript output: {result.stdout}")
-
-        if not output.get('success'):
-            raise gr.Error(f"Inference error: {output.get('error', 'Unknown error')}")
-
-        if progress_callback:
-            progress_callback(0.9, "Decoding image...")
-
-        # Decode base64 image
-        import base64
-        image_data = base64.b64decode(output['image_base64'])
-
-        if progress_callback:
-            progress_callback(1.0, "Complete!")
-
-        return image_data
-
-    except subprocess.TimeoutExpired:
-        raise gr.Error("Inference timed out. Please try again.")
-    except Exception as e:
-        raise gr.Error(f"Error running JavaScript inference: {str(e)}")
-    finally:
-        # Clean up temporary files
-        try:
-            os.unlink(temp_image_path)
-            os.unlink(js_file_path)
-        except:
-            pass
+def load_model():
+    """Load the model on CPU first, then move to GPU when needed"""
+    global pipe
+    if pipe is None:
+        # Register HEIF opener with PIL for AVIF/HEIF support
+        pillow_heif.register_heif_opener()
+
+        # Get token from environment variable
+        hf_token = os.getenv("HF_TOKEN")
+        if hf_token:
+            pipe = FluxKontextPipeline.from_pretrained(
+                "black-forest-labs/FLUX.1-Kontext-dev",
+                torch_dtype=torch.bfloat16,
+                token=hf_token,
+            )
+        else:
+            raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")
+    return pipe
 
 # --- Core Inference Function for ChatInterface ---
-def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
+@spaces.GPU(duration=120)  # Set duration based on expected inference time
+def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress(track_tqdm=True)):
     """
     Performs image generation or editing based on user input from the chat interface.
     """
+    # Load and move model to GPU within the decorated function
+    pipe = load_model()
+    pipe = pipe.to("cuda")
+
     prompt = message["text"]
     files = message["files"]
 
@@ -170,12 +55,12 @@ def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps,
     if randomize_seed:
        seed = random.randint(0, MAX_SEED)
 
+    generator = torch.Generator(device="cuda").manual_seed(int(seed))
+
+    input_image = None
     if files:
         print(f"Received image: {files[0]}")
         try:
-            # Register HEIF opener with PIL for AVIF/HEIF support
-            pillow_heif.register_heif_opener()
-
             # Try to open and convert the image
             input_image = Image.open(files[0])
             # Convert to RGB if needed (handles RGBA, P, etc.)
@@ -183,42 +68,31 @@ def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps,
             input_image = input_image.convert("RGB")
             # Auto-orient the image based on EXIF data
             input_image = ImageOps.exif_transpose(input_image)
-
-            # Convert PIL image to bytes
-            img_byte_arr = io.BytesIO()
-            input_image.save(img_byte_arr, format='PNG')
-            img_byte_arr.seek(0)
-            image_bytes = img_byte_arr.getvalue()
-
         except Exception as e:
             raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).")
 
-        progress(0.1, desc="Processing image...")
+        image = pipe(
+            image=input_image,
+            prompt=prompt,
+            guidance_scale=guidance_scale,
+            num_inference_steps=steps,
+            generator=generator,
+        ).images[0]
     else:
-        # For text-to-image, we need a placeholder image or handle differently
-        # FLUX.1 Kontext is primarily an image-to-image model
-        raise gr.Error("This model (FLUX.1 Kontext) requires an input image. Please upload an image to edit.")
-
-    try:
-        # Make API request using JavaScript
-        result_bytes = query_api_js(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=progress)
-
-        # Try to convert response bytes to PIL Image
-        try:
-            image = Image.open(io.BytesIO(result_bytes))
-        except Exception as img_error:
-            print(f"Failed to open image: {img_error}")
-            print(f"Image bytes type: {type(result_bytes)}, length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")
-            raise gr.Error(f"Could not process API response as image. Response length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")
-
-        progress(1.0, desc="Complete!")
-        return gr.Image(value=image)
-
-    except gr.Error:
-        # Re-raise gradio errors as-is
-        raise
-    except Exception as e:
-        raise gr.Error(f"Failed to generate image: {str(e)}")
+        print(f"Received prompt for text-to-image: {prompt}")
+        image = pipe(
+            prompt=prompt,
+            guidance_scale=guidance_scale,
+            num_inference_steps=steps,
+            generator=generator,
+        ).images[0]
+
+    # Move model back to CPU to free GPU memory
+    pipe = pipe.to("cpu")
+    torch.cuda.empty_cache()
+
+    # Return the PIL Image as a Gradio Image component
+    return gr.Image(value=image)
 
 # --- UI Definition using gr.ChatInterface ---
 
@@ -227,24 +101,26 @@ randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
 guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
 steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)
 
+# --- Examples without external URLs ---
+# Remove examples temporarily to avoid format issues
+examples = None
+
 demo = gr.ChatInterface(
     fn=chat_fn,
-    title="FLUX.1 Kontext [dev] - HF Inference Client (JS)",
+    title="FLUX.1 Kontext [dev]",
     description="""<p style='text-align: center;'>
-    A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using Hugging Face Inference Client via JavaScript.
+    A simple chat UI for the <b>FLUX.1 Kontext</b> model running on ZeroGPU.
     <br>
-    <b>Upload an image</b> and type your editing instructions (e.g., "Turn the cat into a tiger", "Add a hat").
+    To edit an image, upload it and type your instructions (e.g., "Add a hat").
     <br>
-    This model specializes in understanding context and making precise edits to your images.
+    To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
     <br>
     Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
-    <br>
-    <b>Requirements:</b> Node.js and npm must be installed. Uses HF_TOKEN environment variable.
     </p>""",
-    multimodal=True,
+    multimodal=True,  # This is important for MultimodalTextbox to work
     textbox=gr.MultimodalTextbox(
         file_types=["image"],
-        placeholder="Upload an image and type your editing instructions...",
+        placeholder="Type a prompt and/or upload an image...",
         render=False
     ),
     additional_inputs=[
@@ -253,6 +129,7 @@ demo = gr.ChatInterface(
        guidance_slider,
        steps_slider
    ],
+    examples=examples,
    theme="soft"
 )
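
For quick sanity-checking outside the Space, the new diffusers code path can be exercised on its own. The sketch below mirrors the pipeline calls in the updated app.py; it assumes a CUDA device, a diffusers build that provides FluxKontextPipeline, an HF_TOKEN with access to the gated model, and a placeholder input file input.png:

    import os
    import torch
    from PIL import Image
    from diffusers import FluxKontextPipeline

    # Same model, dtype, and call signature as the updated app.py
    pipe = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16,
        token=os.getenv("HF_TOKEN"),
    ).to("cuda")

    # Fixed seed for a reproducible run, as chat_fn does with its seed input
    generator = torch.Generator(device="cuda").manual_seed(42)

    # Image-editing branch; drop image= for plain text-to-image
    image = pipe(
        image=Image.open("input.png").convert("RGB"),  # placeholder file name
        prompt="Add a hat",
        guidance_scale=2.5,
        num_inference_steps=28,
        generator=generator,
    ).images[0]
    image.save("output.png")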