Gemini899 committed
Commit e3f032e · verified · 1 Parent(s): aba8cf1

Update app.py

Files changed (1)
  1. app.py +135 -210
app.py CHANGED
@@ -3,103 +3,69 @@ import gradio as gr
  import re
  from PIL import Image
  import io
  import os
  import numpy as np
  import torch
  from diffusers import FluxImg2ImgPipeline
- import tempfile
- import secrets
- import uuid
- import shutil
- import ssl
  from cryptography.fernet import Fernet
- import base64
- import hashlib
- import time
- import threading
 
- # Global encryption key for this session
- ENCRYPTION_KEY = Fernet.generate_key()
- cipher_suite = Fernet(ENCRYPTION_KEY)
 
- # Configure SSL context for secure connections - FIXED TO AVOID DEPRECATION WARNING
- ssl_context = ssl.create_default_context()
- # Use the recommended modern approach instead of deprecated options
- ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
- ssl_context.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20')
- ssl_context.check_hostname = True
- ssl_context.verify_mode = ssl.CERT_REQUIRED
 
- # Secure temporary directory manager
- class SecureTempManager:
-     def __init__(self):
-         self.temp_dir = tempfile.mkdtemp(prefix='flux_secure_')
-         self.cleanup_timeout = 30  # seconds
-         self.file_registry = {}
-
-     def get_secure_path(self, prefix="img"):
-         """Generate a secure random filename in the temp directory"""
-         filename = f"{prefix}_{uuid.uuid4().hex}_{secrets.token_hex(8)}.png"
-         filepath = os.path.join(self.temp_dir, filename)
-
-         # Register file for cleanup
-         self.file_registry[filepath] = time.time()
-
-         return filepath
 
-     def cleanup_old_files(self):
-         """Clean up files older than the timeout"""
-         current_time = time.time()
-         for filepath, created_time in list(self.file_registry.items()):
-             if current_time - created_time > self.cleanup_timeout:
-                 try:
-                     if os.path.exists(filepath):
-                         # Securely delete by overwriting with random data
-                         file_size = os.path.getsize(filepath)
-                         with open(filepath, 'wb') as f:
-                             f.write(os.urandom(file_size))
-                         # Then delete
-                         os.remove(filepath)
-                         # Remove from registry
-                         del self.file_registry[filepath]
-                 except Exception as e:
-                     print(f"Error cleaning up file {filepath}: {e}")
 
-     def cleanup_all(self):
-         """Clean up all files and remove temp directory"""
-         # First clean up individual files
-         for filepath in list(self.file_registry.keys()):
-             try:
-                 if os.path.exists(filepath):
-                     os.remove(filepath)
-                 del self.file_registry[filepath]
-             except:
-                 pass
-
-         # Then remove the directory
-         try:
-             if os.path.exists(self.temp_dir):
-                 shutil.rmtree(self.temp_dir)
-         except:
-             pass
-
- # Initialize secure temp manager
- secure_temp = SecureTempManager()
-
- # Start a thread to periodically clean up old files
- def cleanup_thread_function():
-     while True:
-         secure_temp.cleanup_old_files()
-         time.sleep(5)  # Check every 5 seconds
-
- cleanup_thread = threading.Thread(target=cleanup_thread_function)
- cleanup_thread.daemon = True  # Thread will exit when main program exits
- cleanup_thread.start()
 
- # Initialize model with proper settings
- dtype = torch.bfloat16
- device = "cuda" if torch.cuda.is_available() else "cpu"
- pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(device)
 
  def sanitize_prompt(prompt):
      # Allow only alphanumeric characters, spaces, and basic punctuation
@@ -126,71 +92,21 @@ def adjust_to_multiple_of_32(width: int, height: int):
      height = height - (height % 32)
      return width, height
 
- # Function to securely handle image data
- def secure_image_handler(image):
-     """Process image securely without exposing it to the file system"""
-     if image is None:
-         return None
-
-     # If the image is already a PIL Image, use it directly
-     if isinstance(image, Image.Image):
-         return image
-
-     # Otherwise, assume it's a file path or binary data
-     try:
-         if isinstance(image, str) and os.path.exists(image):
-             # It's a file path, load it securely
-             with open(image, 'rb') as f:
-                 img_data = f.read()
-
-             # Immediately delete the original file if it's in our temp directory
-             if image.startswith(secure_temp.temp_dir):
-                 try:
-                     os.remove(image)
-                 except:
-                     pass
-
-             # Create image from binary data
-             return Image.open(io.BytesIO(img_data))
-         elif isinstance(image, bytes):
-             # It's binary data
-             return Image.open(io.BytesIO(image))
-     except Exception as e:
-         print(f"Error processing image: {e}")
-         return None
-
  @spaces.GPU(duration=120)
- def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
      progress(0, desc="Starting")
-
-     # Sanitize input
-     prompt = sanitize_prompt(prompt)
-
      def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
-         # Secure image handling
-         image = secure_image_handler(image)
-
          if image is None:
-             print("Empty input image returned")
              return None
-
          generator = torch.Generator(device).manual_seed(seed)
          fit_width, fit_height = convert_to_fit_size(image.size)
          width, height = adjust_to_multiple_of_32(fit_width, fit_height)
          image = image.resize((width, height), Image.LANCZOS)
 
-         # Process the image
-         output = pipe(
-             prompt=prompt,
-             image=image,
-             generator=generator,
-             strength=strength,
-             width=width,
-             height=height,
-             guidance_scale=0,
-             num_inference_steps=num_inference_steps,
-             max_sequence_length=256
-         )
 
          pil_image = output.images[0]
          new_width, new_height = pil_image.size
@@ -200,11 +116,24 @@ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step
              return resized_image
          return pil_image
 
-     # Process the image
      output = process_img2img(image, prompt, strength, seed, inference_step)
 
-     # Return the image directly (Gradio will handle it)
-     return output
 
  def read_file(path: str) -> str:
      with open(path, 'r', encoding='utf-8') as f:
@@ -236,98 +165,94 @@ css="""
  .text {
      font-size: 16px;
  }
- """
 
- # Custom HTTP headers for security
- custom_headers = {
-     "Strict-Transport-Security": "max-age=63072000; includeSubDomains; preload",
-     "X-Content-Type-Options": "nosniff",
-     "X-Frame-Options": "SAMEORIGIN",
-     "Content-Security-Policy": "default-src 'self'; img-src 'self' data:; style-src 'self' 'unsafe-inline';",
-     "Referrer-Policy": "strict-origin-when-cross-origin",
-     "Permissions-Policy": "camera=(), microphone=(), geolocation=()",
-     "Cache-Control": "no-store, max-age=0"
  }
 
- # Create Gradio app with enhanced security
  with gr.Blocks(css=css, elem_id="demo-container") as demo:
      with gr.Column():
          gr.HTML(read_file("demo_header.html"))
          gr.HTML(read_file("demo_tools.html"))
          with gr.Row():
              with gr.Column():
-                 image = gr.Image(
-                     height=800,
-                     sources=['upload','clipboard'],
-                     image_mode='RGB',
-                     elem_id="image_upload",
-                     type="pil",
-                     label="Upload"
-                 )
                  with gr.Row(elem_id="prompt-container", equal_height=False):
                      with gr.Row():
-                         prompt = gr.Textbox(
-                             label="Prompt",
-                             value="a women",
-                             placeholder="Your prompt (what you want in place of what is erased)",
-                             elem_id="prompt"
-                         )
 
                  btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
 
                  with gr.Accordion(label="Advanced Settings", open=False):
                      with gr.Row(equal_height=True):
-                         strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="strength")
-                         seed = gr.Number(value=100, minimum=0, step=1, label="seed")
-                         inference_step = gr.Number(value=4, minimum=1, step=4, label="inference_step")
-                         id_input=gr.Text(label="Name", visible=False)
 
              with gr.Column():
-                 image_out = gr.Image(
-                     height=800,
-                     sources=[],
-                     label="Output",
-                     elem_id="output-img",
-                     format="jpg"
-                 )
 
          gr.Examples(
              examples=[
-                 ["examples/draw_input.jpg", "examples/draw_output.jpg", "a women ,eyes closed,mouth opened"],
-                 ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a women ,eyes closed,mouth opened"],
-                 ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a women ,hand on neck"],
-                 ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a women ,hand on neck"]
              ],
              inputs=[image, image_out, prompt],
          )
-         gr.HTML(
-             gr.HTML(read_file("demo_footer.html"))
-         )
          gr.on(
              triggers=[btn.click, prompt.submit],
-             fn=process_images,
-             inputs=[image, prompt, strength, seed, inference_step],
-             outputs=[image_out]
          )
-
-     # Register shutdown handler to clean up
-     import atexit
-     atexit.register(secure_temp.cleanup_all)
 
  if __name__ == "__main__":
-     # Launch with security settings - FIXED: Removed 'enable_queue' parameter
-     demo.launch(
-         share=True,
-         show_error=True,
-         favicon_path=None,
-         server_name="0.0.0.0",  # Listen on all interfaces
-         server_port=7860,  # Default Gradio port
-         inbrowser=False,
-         debug=False,  # Disable in production
-         quiet=True,  # Less logging for security
-         height=900,
-         width=1600,
-         max_threads=20,
-         auth=None,  # Enable if you need authentication
-         root_path=""
-     )
 
@@ -3,103 +3,69 @@ import gradio as gr
  import re
  from PIL import Image
  import io
+ import base64
  import os
  import numpy as np
  import torch
  from diffusers import FluxImg2ImgPipeline
  from cryptography.fernet import Fernet
+ from cryptography.hazmat.primitives import hashes
+ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
+ pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(device)
 
+ # Encryption setup
+ def generate_key(password, salt=None):
+     if salt is None:
+         salt = os.urandom(16)
+     kdf = PBKDF2HMAC(
+         algorithm=hashes.SHA256(),
+         length=32,
+         salt=salt,
+         iterations=100000,
+     )
+     key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
+     return key, salt
+
+ def encrypt_image(image, password="default_password"):
+     # Convert PIL Image to bytes
+     img_byte_arr = io.BytesIO()
+     image.save(img_byte_arr, format='PNG')
+     img_byte_arr = img_byte_arr.getvalue()
+
+     # Generate key for encryption
+     key, salt = generate_key(password)
+     cipher = Fernet(key)
+
+     # Encrypt the image bytes
+     encrypted_data = cipher.encrypt(img_byte_arr)
+
+     # Return the encrypted data and salt (needed for decryption)
+     return {
+         'encrypted_data': base64.b64encode(encrypted_data).decode('utf-8'),
+         'salt': base64.b64encode(salt).decode('utf-8'),
+         'original_width': image.width,
+         'original_height': image.height
+     }
 
+ def decrypt_image(encrypted_data_dict, password="default_password"):
+     # Extract the encrypted data and salt
+     encrypted_data = base64.b64decode(encrypted_data_dict['encrypted_data'])
+     salt = base64.b64decode(encrypted_data_dict['salt'])
+
+     # Regenerate the key using the provided salt
+     key, _ = generate_key(password, salt)
+     cipher = Fernet(key)
+
+     # Decrypt the data
+     decrypted_data = cipher.decrypt(encrypted_data)
+
+     # Convert bytes back to PIL Image
+     image = Image.open(io.BytesIO(decrypted_data))
+     return image
 
  def sanitize_prompt(prompt):
      # Allow only alphanumeric characters, spaces, and basic punctuation
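The helpers added above derive a Fernet key from a password via PBKDF2-HMAC-SHA256 with a random 16-byte salt, then encrypt the PNG bytes of a PIL image; storing the salt next to the ciphertext is what lets the same password decrypt later. Below is a minimal, self-contained sketch of that round trip (the password, image, and helper name derive_key are illustrative, not part of the commit):

import base64, io, os
from PIL import Image
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

def derive_key(password: str, salt: bytes) -> bytes:
    # PBKDF2-HMAC-SHA256 -> 32 bytes, urlsafe-base64 encoded as Fernet expects
    kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, salt=salt, iterations=100000)
    return base64.urlsafe_b64encode(kdf.derive(password.encode()))

salt = os.urandom(16)                                   # random salt, as in generate_key()
key = derive_key("my secret", salt)                     # placeholder password

img = Image.new("RGB", (64, 64), color=(200, 30, 30))   # stand-in for a model output
buf = io.BytesIO()
img.save(buf, format="PNG")                             # encrypt the PNG bytes, as encrypt_image() does

token = Fernet(key).encrypt(buf.getvalue())

# Decryption needs only the password and the stored salt.
restored = Image.open(io.BytesIO(Fernet(derive_key("my secret", salt)).decrypt(token)))
assert restored.size == img.size

Because Fernet keys must be 32 url-safe base64-encoded bytes, the KDF output is base64-encoded before constructing the cipher.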
 
@@ -126,71 +92,21 @@ def adjust_to_multiple_of_32(width: int, height: int):
      height = height - (height % 32)
      return width, height
 
  @spaces.GPU(duration=120)
+ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, encrypt_password="default_password", progress=gr.Progress(track_tqdm=True)):
      progress(0, desc="Starting")
+
      def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
          if image is None:
+             print("empty input image returned")
              return None
          generator = torch.Generator(device).manual_seed(seed)
          fit_width, fit_height = convert_to_fit_size(image.size)
          width, height = adjust_to_multiple_of_32(fit_width, fit_height)
          image = image.resize((width, height), Image.LANCZOS)
 
+         output = pipe(prompt=prompt, image=image, generator=generator, strength=strength, width=width, height=height,
+                       guidance_scale=0, num_inference_steps=num_inference_steps, max_sequence_length=256)
 
          pil_image = output.images[0]
          new_width, new_height = pil_image.size
 
@@ -200,11 +116,24 @@ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step
              return resized_image
          return pil_image
 
      output = process_img2img(image, prompt, strength, seed, inference_step)
 
+     # Encrypt the output image
+     if output is not None:
+         encrypted_output = encrypt_image(output, encrypt_password)
+
+         # For display purposes, we'll create a placeholder image with text indicating encryption
+         placeholder = Image.new('RGB', (output.width, output.height), color=(220, 220, 220))
+         return {
+             "display_image": placeholder,
+             "encrypted_data": encrypted_output
+         }
+     return None
+
+ def save_encrypted_image(encrypted_data, filename="encrypted_image.enc"):
+     with open(filename, 'w') as f:
+         json.dump(encrypted_data, f)
+     return f"Encrypted image saved as {filename}"
 
  def read_file(path: str) -> str:
      with open(path, 'r', encoding='utf-8') as f:
 
@@ -236,98 +165,94 @@ css="""
  .text {
      font-size: 16px;
  }
 
+ .encryption-notice {
+     background-color: #f0f0f0;
+     padding: 15px;
+     border-radius: 5px;
+     margin-top: 10px;
+     text-align: center;
  }
+ """
 
  with gr.Blocks(css=css, elem_id="demo-container") as demo:
+     # Store encrypted data in a state variable
+     encrypted_output_state = gr.State(None)
+
      with gr.Column():
          gr.HTML(read_file("demo_header.html"))
          gr.HTML(read_file("demo_tools.html"))
          with gr.Row():
              with gr.Column():
+                 image = gr.Image(height=800, sources=['upload','clipboard'], image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
                  with gr.Row(elem_id="prompt-container", equal_height=False):
                      with gr.Row():
+                         prompt = gr.Textbox(label="Prompt", value="a women", placeholder="Your prompt (what you want in place of what is erased)", elem_id="prompt")
 
                  btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
 
                  with gr.Accordion(label="Advanced Settings", open=False):
                      with gr.Row(equal_height=True):
+                         strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="Strength")
+                         seed = gr.Number(value=100, minimum=0, step=1, label="Seed")
+                         inference_step = gr.Number(value=4, minimum=1, step=4, label="Inference Steps")
+                         encrypt_password = gr.Textbox(label="Encryption Password", value="default_password", type="password")
+                         id_input = gr.Text(label="Name", visible=False)
 
              with gr.Column():
+                 # Display placeholder image
+                 image_out = gr.Image(height=800, sources=[], label="Output (Encrypted)", elem_id="output-img", format="jpg")
+                 encryption_notice = gr.HTML('<div class="encryption-notice">The output image is encrypted. Use the Save button to download the encrypted file.</div>')
+                 save_btn = gr.Button("Save Encrypted Image")
+                 save_result = gr.Text(label="Save Result")
 
+         # Examples section
          gr.Examples(
              examples=[
+                 ["examples/draw_input.jpg", "examples/draw_output.jpg", "a women, eyes closed, mouth opened"],
+                 ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a women, eyes closed, mouth opened"],
+                 ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a women, hand on neck"],
+                 ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a women, hand on neck"]
              ],
              inputs=[image, image_out, prompt],
          )
+
+         gr.HTML(read_file("demo_footer.html"))
+
+         # Process images and encrypt outputs
+         def handle_image_generation(image, prompt, strength, seed, inference_step, encrypt_password):
+             result = process_images(image, prompt, strength, seed, inference_step, encrypt_password)
+             if result:
+                 return result["display_image"], result["encrypted_data"]
+             return None, None
+
          gr.on(
              triggers=[btn.click, prompt.submit],
+             fn=handle_image_generation,
+             inputs=[image, prompt, strength, seed, inference_step, encrypt_password],
+             outputs=[image_out, encrypted_output_state]
+         )
+
+         # Save encrypted image
+         def handle_save_encrypted(encrypted_data):
+             if encrypted_data:
+                 import json
+                 import tempfile
+                 import os
+
+                 # Create a temporary file with the encrypted data
+                 fd, path = tempfile.mkstemp(suffix='.encimg')
+                 with os.fdopen(fd, 'w') as f:
+                     json.dump(encrypted_data, f)
+
+                 return f"Encrypted image saved to {path}"
+             return "No encrypted image to save"
+
+         save_btn.click(
+             fn=handle_save_encrypted,
+             inputs=[encrypted_output_state],
+             outputs=[save_result]
          )
 
  if __name__ == "__main__":
+     demo.launch(share=True, show_error=True)
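In the new version, the Save Encrypted Image button routes the encrypt_image() dictionary held in encrypted_output_state through handle_save_encrypted(), which writes it as JSON to a temporary *.encimg file and reports the path (the standalone save_encrypted_image() helper is not wired to any UI event and would additionally need a module-level json import). A hedged sketch of reading such a file back outside the app, mirroring decrypt_image(); the path, password, and output filename are placeholders:

import base64, io, json
from PIL import Image
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

# Placeholder path; the app reports the real tempfile path in the Save Result box.
with open("path/to/saved_file.encimg", "r") as f:
    payload = json.load(f)

# Re-derive the Fernet key exactly as generate_key() does, using the stored salt.
salt = base64.b64decode(payload["salt"])
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, salt=salt, iterations=100000)
key = base64.urlsafe_b64encode(kdf.derive("default_password".encode()))  # must match the UI password

png_bytes = Fernet(key).decrypt(base64.b64decode(payload["encrypted_data"]))
Image.open(io.BytesIO(png_bytes)).save("decrypted_output.png")

Decryption succeeds only with the password that was entered in the Encryption Password field when the image was generated, since a different password derives a different Fernet key for the stored salt.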