img2img_test / app.py
Gemini899's picture
Update app.py
547723f verified
raw
history blame
10.2 kB
import spaces
import gradio as gr
import re
from PIL import Image
import io
import base64
import os
import json
import numpy as np
import torch
from diffusers import FluxImg2ImgPipeline
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
# Numeric precision for the FLUX weights; used below so there is a single
# place to change it.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# The img2img pipeline is loaded once at import time and shared by every
# request handled by this module.
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)
# Encryption setup
def generate_key(password, salt=None, *, iterations=100000):
    """Derive a Fernet key from *password* using PBKDF2-HMAC-SHA256.

    Args:
        password: Passphrase as ``str`` (encoded as UTF-8) or raw ``bytes``.
        salt: Salt bytes to reuse (e.g. when decrypting). When ``None`` a
            fresh random 16-byte salt is generated.
        iterations: PBKDF2 iteration count. The default must stay in sync
            with data that was already encrypted: ``encrypt_image`` stores the
            salt next to the ciphertext but does not store the iteration
            count, so changing the default would make old files undecryptable.

    Returns:
        ``(key, salt)``: *key* is the 32-byte derived key encoded with
        URL-safe base64 (the form ``Fernet`` expects); *salt* is the salt
        that was actually used.

    Raises:
        TypeError: if *password* is neither ``str`` nor ``bytes``.
    """
    if isinstance(password, str):
        password_bytes = password.encode()
    elif isinstance(password, (bytes, bytearray)):
        password_bytes = bytes(password)
    else:
        raise TypeError(
            f"password must be str or bytes, not {type(password).__name__}"
        )
    if salt is None:
        salt = os.urandom(16)
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=iterations,
    )
    key = base64.urlsafe_b64encode(kdf.derive(password_bytes))
    return key, salt
def encrypt_image(image, password="default_password"):
    """Serialize *image* as PNG and encrypt it with a password-derived key.

    A new random salt is drawn on every call (via ``generate_key``), so the
    same image and password give a different ciphertext each time.

    Returns:
        A JSON-serialisable dict with the base64-encoded Fernet token
        (``'encrypted_data'``), the base64-encoded salt (``'salt'``) needed to
        re-derive the key, and the image's ``'original_width'`` and
        ``'original_height'``.
    """
    png_buffer = io.BytesIO()
    image.save(png_buffer, format='PNG')
    png_bytes = png_buffer.getvalue()

    derived_key, salt_bytes = generate_key(password)
    token = Fernet(derived_key).encrypt(png_bytes)

    def as_text(raw):
        # Standard base64 wrapped as str so the dict can go straight to JSON.
        return base64.b64encode(raw).decode('utf-8')

    return {
        'encrypted_data': as_text(token),
        'salt': as_text(salt_bytes),
        'original_width': image.width,
        'original_height': image.height,
    }
def decrypt_image(encrypted_data_dict, password="default_password"):
    """Reverse ``encrypt_image``: decrypt the payload and return a PIL image.

    Args:
        encrypted_data_dict: Dict with base64 ``'encrypted_data'`` and
            ``'salt'`` entries, as produced by ``encrypt_image``.
        password: The password used at encryption time.

    Returns:
        The decoded ``PIL.Image.Image``, fully loaded into memory.

    Raises:
        KeyError: if a required entry is missing from the dict.
        cryptography.fernet.InvalidToken: if the password is wrong or the
            token was tampered with (raised by ``Fernet.decrypt``).
    """
    token = base64.b64decode(encrypted_data_dict['encrypted_data'])
    salt = base64.b64decode(encrypted_data_dict['salt'])

    # Re-derive the same key from the stored salt.
    key, _ = generate_key(password, salt)
    decrypted_data = Fernet(key).decrypt(token)

    image = Image.open(io.BytesIO(decrypted_data))
    # Image.open only reads the header; decode now so corrupt or truncated
    # image data fails here instead of at some later, unrelated call site.
    image.load()
    return image
# Matches every character that is NOT allowed in a prompt: anything other than
# ASCII letters, digits, whitespace, and the punctuation . , ! ? -
_DISALLOWED_PROMPT_CHARS = re.compile(r"[^a-zA-Z0-9\s.,!?-]")


def sanitize_prompt(prompt):
    """Return *prompt* with every character outside the whitelist removed.

    Kept: ASCII letters, digits, whitespace and ``. , ! ? -``. Everything
    else, including non-ASCII letters, is dropped. ``None`` is treated as an
    empty prompt.
    """
    if prompt is None:
        return ""
    return _DISALLOWED_PROMPT_CHARS.sub("", prompt)
def convert_to_fit_size(original_width_and_height, maximum_size=2048):
    """Scale a ``(width, height)`` pair down so neither side exceeds *maximum_size*.

    The aspect ratio is preserved. The longer side becomes exactly
    *maximum_size*, and the shorter side is rounded down but never below 1.
    Sizes that already fit are returned unchanged.

    Integer arithmetic is used so the long side cannot end up one pixel short
    because of floating-point rounding (``int(w * (m / w))`` may yield ``m - 1``).
    """
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height
    if width > height:
        return int(maximum_size), max(1, int(height * maximum_size // width))
    return max(1, int(width * maximum_size // height)), int(maximum_size)
def adjust_to_multiple_of_32(width: int, height: int):
    """Round *width* and *height* down to multiples of 32, with a floor of 32.

    The floor keeps very thin images from collapsing to a 0-pixel side, which
    would make the subsequent resize / pipeline call fail.
    """
    def floor_to_32(value: int) -> int:
        return max(32, value - value % 32)

    return floor_to_32(width), floor_to_32(height)
@spaces.GPU(duration=120)
def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4,
                   encrypt_password="default_password", progress=gr.Progress(track_tqdm=True)):
    """Run FLUX img2img on *image* and return an encrypted copy of the result.

    Args:
        image: Input ``PIL.Image`` (``None`` yields ``None``).
        prompt: Text prompt passed to the pipeline.
        strength: img2img strength passed to the pipeline.
        seed: RNG seed. May arrive as a float from ``gr.Number`` or API
            clients, so it is converted to ``int``.
        inference_step: Number of denoising steps, also converted to ``int``.
        encrypt_password: Password for ``encrypt_image``.
        progress: Gradio progress tracker.

    Returns:
        ``None`` when there is no input image. Otherwise a dict with
        ``"display_image"``, a flat grey placeholder the size of the output,
        and ``"encrypted_data"``, the dict returned by ``encrypt_image``.
    """
    progress(0, desc="Starting")

    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
        if image is None:
            print("empty input image returned")
            return None
        # torch.Generator.manual_seed and the diffusers scheduler need ints;
        # Number inputs may be floats or cleared (None).
        seed = 0 if seed is None else int(seed)
        num_inference_steps = 4 if num_inference_steps is None else int(num_inference_steps)
        generator = torch.Generator(device).manual_seed(seed)

        fit_width, fit_height = convert_to_fit_size(image.size)
        width, height = adjust_to_multiple_of_32(fit_width, fit_height)
        image = image.resize((width, height), Image.LANCZOS)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            strength=strength,
            width=width,
            height=height,
            guidance_scale=0,
            num_inference_steps=num_inference_steps,
            max_sequence_length=256,
        )
        pil_image = output.images[0]
        # The pipeline ran at the multiple-of-32 size. Scale back to the size
        # that fit the limit, so the aspect ratio matches the input again.
        if pil_image.size != (fit_width, fit_height):
            return pil_image.resize((fit_width, fit_height), Image.LANCZOS)
        return pil_image

    output = process_img2img(image, prompt, strength, seed, inference_step)
    if output is None:
        return None

    encrypted_output = encrypt_image(output, encrypt_password)
    # Only a neutral placeholder is shown; the real pixels exist only in
    # encrypted form.
    placeholder = Image.new('RGB', (output.width, output.height), color=(220, 220, 220))
    return {
        "display_image": placeholder,
        "encrypted_data": encrypted_output,
    }
def save_encrypted_image(encrypted_data, filename="encrypted_image.enc"):
    """Write *encrypted_data* to *filename* as JSON and return a status message.

    An existing file at *filename* is overwritten.
    """
    with open(filename, 'w') as out_file:
        json.dump(encrypted_data, out_file)
    status = f"Encrypted image saved as {filename}"
    return status
def read_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at *path*."""
    with open(path, encoding='utf-8') as handle:
        return handle.read()
# Custom stylesheet passed to gr.Blocks(css=...). `.encryption-notice` styles
# the notice HTML shown under the output image. NOTE(review): #col-left,
# #col-right, .grid-container, .image and .text are not referenced by any
# elem_id or class in this file — presumably used by the demo_*.html snippets;
# confirm before removing.
css = """
#col-left {
margin: 0 auto;
max-width: 640px;
}
#col-right {
margin: 0 auto;
max-width: 640px;
}
.grid-container {
display: flex;
align-items: center;
justify-content: center;
gap:10px
}
.image {
width: 128px;
height: 128px;
object-fit: cover;
}
.text {
font-size: 16px;
}
.encryption-notice {
background-color: #f0f0f0;
padding: 15px;
border-radius: 5px;
margin-top: 10px;
text-align: center;
}
"""
with gr.Blocks(css=css, elem_id="demo-container") as demo:
    # Per-session holder for the encrypted payload of the latest generation.
    # The generate event writes it and the Save button reads it.
    encrypted_output_state = gr.State(None)

    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))

    with gr.Row():
        # Left column: input image, prompt and generation settings.
        with gr.Column():
            image = gr.Image(
                height=800,
                sources=['upload', 'clipboard'],
                image_mode='RGB',
                elem_id="image_upload",
                type="pil",
                label="Upload"
            )
            with gr.Row(elem_id="prompt-container", equal_height=False):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="a women",
                        placeholder="Your prompt (what you want in place of what is erased)",
                        elem_id="prompt"
                    )
            btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row(equal_height=True):
                    strength = gr.Number(
                        value=0.75, minimum=0, maximum=0.75, step=0.01, label="Strength"
                    )
                    seed = gr.Number(
                        value=100, minimum=0, step=1, label="Seed"
                    )
                    inference_step = gr.Number(
                        value=4, minimum=1, step=4, label="Inference Steps"
                    )
                # NOTE(review): this default password is a constant visible in
                # the source, so outputs encrypted without changing it are
                # effectively unprotected. Consider requiring user input.
                encrypt_password = gr.Textbox(
                    label="Encryption Password",
                    value="default_password",
                    type="password"
                )
                id_input = gr.Text(label="Name", visible=False)

        # Right column: placeholder output and save controls.
        with gr.Column():
            image_out = gr.Image(
                height=800,
                sources=[],
                label="Output (Encrypted)",
                elem_id="output-img",
                format="jpg"
            )
            encryption_notice = gr.HTML(
                '<div class="encryption-notice">'
                'The output image is encrypted. Use the Save button to download the encrypted file.'
                '</div>'
            )
            save_btn = gr.Button("Save Encrypted Image")
            save_result = gr.Text(label="Save Result")

    # Examples section
    gr.Examples(
        examples=[
            ["examples/draw_input.jpg", "examples/draw_output.jpg", "a women, eyes closed, mouth opened"],
            ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a women, eyes closed, mouth opened"],
            ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a women, hand on neck"],
            ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a women, hand on neck"]
        ],
        inputs=[image, image_out, prompt],
    )
    gr.HTML(read_file("demo_footer.html"))

    def handle_image_generation(image, prompt, strength, seed, inference_step, encrypt_password):
        """Run generation and split the result into (placeholder image, encrypted payload).

        Returns ``(None, None)`` when ``process_images`` produced nothing
        (no input image).
        """
        result = process_images(image, prompt, strength, seed, inference_step, encrypt_password)
        if result:
            return result["display_image"], result["encrypted_data"]
        return None, None

    # The button click and Enter in the prompt box are ONE event, so they share
    # ONE API endpoint. Two separate listeners with the same api_name make
    # Gradio rename the second one (adds a numeric suffix) instead of reusing it.
    # NOTE(review): Gradio api_names are usually given without a leading "/";
    # kept as-is so existing API clients keep working — confirm before changing.
    gr.on(
        triggers=[btn.click, prompt.submit],
        fn=handle_image_generation,
        inputs=[image, prompt, strength, seed, inference_step, encrypt_password],
        outputs=[image_out, encrypted_output_state],
        api_name="/process_images"
    )

    def handle_save_encrypted(encrypted_data):
        """Write the session's encrypted payload to a new server-side temp file.

        Returns a status message containing the file path, or a notice when
        nothing has been generated yet. NOTE(review): the file stays on the
        server and is never sent to the browser, although the notice above
        says "download". A gr.File output would be needed for a real download.
        """
        if encrypted_data:
            import tempfile
            fd, path = tempfile.mkstemp(suffix='.encimg')
            with os.fdopen(fd, 'w') as f:
                json.dump(encrypted_data, f)
            return f"Encrypted image saved to {path}"
        return "No encrypted image to save"

    save_btn.click(
        fn=handle_save_encrypted,
        inputs=[encrypted_output_state],
        outputs=[save_result]
    )
if __name__ == "__main__":
    # Start the Gradio server. show_error surfaces Python exceptions in the UI;
    # share asks Gradio for a public share link.
    demo.launch(show_error=True, share=True)