# DreamFuse / app.py — Gradio Space entry point
# (Hugging Face file-viewer header residue converted to a comment:
#  uploader LL3RD, commit f6e3a92 "test", raw/history/blame links, 21.8 kB)
import gradio as gr
import spaces
from PIL import Image, ImageDraw, ImageOps
import base64, json
from io import BytesIO
import torch.nn.functional as F
import json
from typing import List
from dataclasses import dataclass, field
from dreamfuse_inference import DreamFuseInference, InferenceConfig
import numpy as np
import os
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import torch
import subprocess
# Clear any stale ZeroGPU offload cache left over from a previous run of this
# Space. NOTE(review): the glob requires a shell, hence shell=True; env={}
# strips the environment, so /bin/sh's built-in default PATH is relied on to
# find rm — confirm this path exists on the deployment host.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
# Module-level state shared by the Gradio callbacks.
generated_images = []
# Background-removal model (briaai/RMBG-2.0), loaded once at import time and
# moved to the GPU; used by remove_bg() below.
RMBG_model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
RMBG_model = RMBG_model.to("cuda")
# Preprocessing expected by RMBG-2.0: fixed 1024x1024 input with ImageNet
# mean/std normalization.
transform = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
@spaces.GPU
def remove_bg(image):
    """Return a grayscale ("L"-compatible) foreground mask for *image*.

    Converts the input to RGB, preprocesses it with the module-level
    `transform` (1024x1024, ImageNet normalization), runs the RMBG-2.0
    model, and resizes the sigmoid of its final prediction map back to the
    original image size. The returned PIL image is used as an alpha channel
    by the caller.
    """
    im = image.convert("RGB")
    input_tensor = transform(im).unsqueeze(0).to("cuda")
    with torch.no_grad():
        # [-1] takes the model's final-stage prediction map.
        preds = RMBG_model(input_tensor)[-1].sigmoid().cpu()[0].squeeze()
    mask = transforms.ToPILImage()(preds).resize(im.size)
    return mask
class DreamblendGUI:
    """Gradio front-end for DreamFuse: upload foreground/background images,
    position the foreground on an interactive canvas, then run the fusion
    model."""

    def __init__(self):
        # Bundled example (background, foreground) pairs; opened eagerly so
        # gr.Examples can display them directly as PIL images.
        # NOTE(review): Image.open raises if ./examples/ is missing — assumed
        # to ship with the Space.
        self.examples = [
            ["./examples/9_02.png",
            "./examples/9_01.png"],
        ]
        self.examples = [[Image.open(x) for x in example] for example in self.examples]
        self.css_style = self._get_css_style()
        self.js_script = self._get_js_script()
    def _get_css_style(self):
        """Return the custom CSS injected into gr.Blocks: transparent dark
        theme, centered container, dashed preview canvas border, button
        styling, and a narrow floated column for the examples widget."""
        return """
body {
background: transparent;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
color: #fff;
}
.gradio-container {
max-width: 1200px;
margin: auto;
background: transparent;
border-radius: 10px;
padding: 20px;
box-shadow: 0px 2px 8px rgba(255,255,255,0.1);
}
h1, h2 {
text-align: center;
color: #fff;
}
#canvas_preview {
border: 2px dashed rgba(255,255,255,0.5);
padding: 10px;
background: transparent;
border-radius: 8px;
}
.gr-button {
background-color: #007bff;
border: none;
color: #fff;
padding: 10px 20px;
border-radius: 5px;
font-size: 16px;
cursor: pointer;
}
.gr-button:hover {
background-color: #0056b3;
}
#small-examples {
max-width: 200px !important;
width: 200px !important;
float: left;
margin-right: 20px;
}
"""
    def _get_js_script(self):
        """Return the JavaScript bootstrapped on app load (via demo.load).

        Defines two globals used by the preview canvas HTML:

        - ``updateTransformation()``: reads the draggable image's current
          left/top/size plus its original dimensions, and writes them as
          JSON into the hidden #transformation_info textbox using the
          native value setter so Gradio registers the 'input' event.
        - ``initializeDrag()``: uses a MutationObserver to wait until the
          preview elements exist, then binds mouse drag handlers (position
          clamped so at least 1/8 of the image stays inside the canvas) and
          the scale slider (which scales around the image center captured
          at slider mousedown).
        """
        return r"""
async () => {
window.updateTransformation = function() {
const img = document.getElementById('draggable-img');
const container = document.getElementById('canvas-container');
if (!img || !container) return;
const left = parseFloat(img.style.left) || 0;
const top = parseFloat(img.style.top) || 0;
const canvasSize = 400;
const data_original_width = parseFloat(img.getAttribute('data-original-width'));
const data_original_height = parseFloat(img.getAttribute('data-original-height'));
const bgWidth = parseFloat(container.dataset.bgWidth);
const bgHeight = parseFloat(container.dataset.bgHeight);
const scale_ratio = img.clientWidth / data_original_width;
const transformation = {
drag_left: left,
drag_top: top,
drag_width: img.clientWidth,
drag_height: img.clientHeight,
data_original_width: data_original_width,
data_original_height: data_original_height,
scale_ratio: scale_ratio
};
const transInput = document.querySelector("#transformation_info textarea");
if(transInput){
const newValue = JSON.stringify(transformation);
const nativeSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, 'value').set;
nativeSetter.call(transInput, newValue);
transInput.dispatchEvent(new Event('input', { bubbles: true }));
console.log("Transformation info updated: ", newValue);
} else {
console.log("ๆ‰พไธๅˆฐ transformation_info ็š„ textarea ๅ…ƒ็ด ");
}
};
globalThis.initializeDrag = () => {
console.log("ๅˆๅง‹ๅŒ–ๆ‹–ๆ‹ฝไธŽ็ผฉๆ”พๅŠŸ่ƒฝ...");
const observer = new MutationObserver(() => {
const img = document.getElementById('draggable-img');
const container = document.getElementById('canvas-container');
const slider = document.getElementById('scale-slider');
if (img && container && slider) {
observer.disconnect();
console.log("็ป‘ๅฎšๆ‹–ๆ‹ฝไธŽ็ผฉๆ”พไบ‹ไปถ...");
img.ondragstart = (e) => { e.preventDefault(); return false; };
let offsetX = 0, offsetY = 0;
let isDragging = false;
let scaleAnchor = null;
img.addEventListener('mousedown', (e) => {
isDragging = true;
img.style.cursor = 'grabbing';
const imgRect = img.getBoundingClientRect();
offsetX = e.clientX - imgRect.left;
offsetY = e.clientY - imgRect.top;
img.style.transform = "none";
img.style.left = img.offsetLeft + "px";
img.style.top = img.offsetTop + "px";
console.log("mousedown: left=", img.style.left, "top=", img.style.top);
});
document.addEventListener('mousemove', (e) => {
if (!isDragging) return;
e.preventDefault();
const containerRect = container.getBoundingClientRect();
// ่ฎก็ฎ—ๅฝ“ๅ‰ๆ‹–ๆ‹ฝๅŽ็š„ๅๆ ‡๏ผˆๅŸบไบŽๅฎนๅ™จ๏ผ‰
let left = e.clientX - containerRect.left - offsetX;
let top = e.clientY - containerRect.top - offsetY;
// ๅ…่ฎธ็š„ๆ‹–ๆ‹ฝ่Œƒๅ›ด๏ผš
// ๆฐดๅนณๆ–นๅ‘ๅ…่ฎธๆœ€ๅฐ‘่ถ…ๅ‡บๅ›พๅƒไธ€ๅŠ๏ผšๆœ€ๅฐๅ€ผไธบ -img.clientWidth * (7/8)
// ๆฐดๅนณๆ–นๅ‘ๅ…่ฎธๆœ€ๅคš่ถ…ๅ‡บไธ€ๅŠ๏ผšๆœ€ๅคงๅ€ผไธบ containerRect.width - img.clientWidth * (1/8)
const minLeft = -img.clientWidth * (7/8);
const maxLeft = containerRect.width - img.clientWidth * (1/8);
// ๅž‚็›ดๆ–นๅ‘ๅ…่ฎธ่Œƒๅ›ด๏ผš
// ๆœ€ๅฐๅ€ผไธบ -img.clientHeight * (7/8)
// ๆœ€ๅคงๅ€ผไธบ containerRect.height - img.clientHeight * (1/8)
const minTop = -img.clientHeight * (7/8);
const maxTop = containerRect.height - img.clientHeight * (1/8);
// ้™ๅˆถ่Œƒๅ›ด
if (left < minLeft) left = minLeft;
if (left > maxLeft) left = maxLeft;
if (top < minTop) top = minTop;
if (top > maxTop) top = maxTop;
img.style.left = left + "px";
img.style.top = top + "px";
});
window.addEventListener('mouseup', (e) => {
if (isDragging) {
isDragging = false;
img.style.cursor = 'grab';
const containerRect = container.getBoundingClientRect();
const bgWidth = parseFloat(container.dataset.bgWidth);
const bgHeight = parseFloat(container.dataset.bgHeight);
const offsetLeft = (containerRect.width - bgWidth) / 2;
const offsetTop = (containerRect.height - bgHeight) / 2;
const absoluteLeft = parseFloat(img.style.left);
const absoluteTop = parseFloat(img.style.top);
const relativeX = absoluteLeft - offsetLeft;
const relativeY = absoluteTop - offsetTop;
document.getElementById("coordinate").textContent =
`ๅ‰ๆ™ฏๅ›พๅๆ ‡: (x=${relativeX.toFixed(2)}, y=${relativeY.toFixed(2)})`;
updateTransformation();
}
scaleAnchor = null;
});
slider.addEventListener('mousedown', (e) => {
const containerRect = container.getBoundingClientRect();
const imgRect = img.getBoundingClientRect();
scaleAnchor = {
x: imgRect.left + imgRect.width/2 - containerRect.left,
y: imgRect.top + imgRect.height/2 - containerRect.top
};
console.log("Slider mousedown, captured scaleAnchor: ", scaleAnchor);
});
slider.addEventListener('input', (e) => {
const scale = parseFloat(e.target.value);
const originalWidth = parseFloat(img.getAttribute('data-original-width'));
const originalHeight = parseFloat(img.getAttribute('data-original-height'));
const newWidth = originalWidth * scale;
const newHeight = originalHeight * scale;
const containerRect = container.getBoundingClientRect();
let centerX, centerY;
if (scaleAnchor) {
centerX = scaleAnchor.x;
centerY = scaleAnchor.y;
} else {
const imgRect = img.getBoundingClientRect();
centerX = imgRect.left + imgRect.width/2 - containerRect.left;
centerY = imgRect.top + imgRect.height/2 - containerRect.top;
}
const newLeft = centerX - newWidth/2;
const newTop = centerY - newHeight/2;
img.style.width = newWidth + "px";
img.style.height = newHeight + "px";
img.style.left = newLeft + "px";
img.style.top = newTop + "px";
console.log("slider: scale=", scale, "newWidth=", newWidth, "newHeight=", newHeight);
updateTransformation();
});
slider.addEventListener('mouseup', (e) => {
scaleAnchor = null;
});
}
});
observer.observe(document.body, { childList: true, subtree: true });
};
}
"""
def get_next_sequence(self, folder_path):
# ๅˆ—ๅ‡บๆ–‡ไปถๅคนไธญ็š„ๆ‰€ๆœ‰ๆ–‡ไปถๅ
filenames = os.listdir(folder_path)
# ๆๅ–ๆ–‡ไปถๅไธญ็š„ๅบๅˆ—ๅท้ƒจๅˆ†๏ผˆๅ‡่ฎพๆ˜ฏๅ‰ไธ‰ไฝๆ•ฐๅญ—๏ผ‰
sequences = [int(name.split('_')[0]) for name in filenames if name.split('_')[0].isdigit()]
# ๆ‰พๅˆฐๆœ€ๅคงๅบๅˆ—ๅท
max_sequence = max(sequences, default=-1)
# ่ฟ”ๅ›žไธ‹ไธ€ไฝๅบๅˆ—ๅท๏ผŒๆ ผๅผไธบไธ‰ไฝๆ•ฐๅญ—๏ผˆๅฆ‚002๏ผ‰
return f"{max_sequence + 1:03d}"
def pil_to_base64(self, img):
"""ๅฐ† PIL Image ่ฝฌไธบ base64 ๅญ—็ฌฆไธฒ๏ผŒPNG ๆ ผๅผไธ‹ไฟ็•™้€ๆ˜Ž้€š้“"""
if img is None:
return ""
if img.mode != "RGBA":
img = img.convert("RGBA")
buffered = BytesIO()
img.save(buffered, format="PNG", optimize=True)
img_bytes = buffered.getvalue()
base64_str = base64.b64encode(img_bytes).decode()
return f"data:image/png;base64,{base64_str}"
def resize_background_image(self, img, max_size=400):
"""ๅฐ†่ƒŒๆ™ฏๅ›พ็ญ‰ๆฏ”ไพ‹็ผฉๆ”พๅˆฐๆœ€้•ฟ่พนไธบ max_size๏ผˆ400๏ผ‰"""
if img is None:
return None
w, h = img.size
if w > max_size or h > max_size:
ratio = min(max_size / w, max_size / h)
new_w, new_h = int(w * ratio), int(h * ratio)
img = img.resize((new_w, new_h), Image.LANCZOS)
return img
def resize_draggable_image(self, img, max_size=400):
"""ๅฐ†ๅ‰ๆ™ฏๅ›พ็ญ‰ๆฏ”ไพ‹็ผฉๆ”พๅˆฐๆœ€้•ฟ่พนไธ่ถ…่ฟ‡ max_size๏ผˆ400๏ผ‰"""
if img is None:
return None
w, h = img.size
if w > max_size or h > max_size:
ratio = min(max_size / w, max_size / h)
new_w, new_h = int(w * ratio), int(h * ratio)
img = img.resize((new_w, new_h), Image.LANCZOS)
return img
    def generate_html(self, background_img_b64, bg_width, bg_height, draggable_img_b64, draggable_width, draggable_height, canvas_size=400):
        """Build the preview page HTML for the drag-and-drop canvas.

        Renders a canvas_size x canvas_size container with the background
        shown as a contained CSS background-image, an absolutely-positioned
        draggable foreground <img> (initially centered via a 50%/translate
        transform), a scale slider, and a coordinate readout. Both images
        arrive as base64 data-URIs; the background's display size and the
        foreground's original dimensions are embedded as data-* attributes
        for the JS handlers in _get_js_script.
        """
        html_code = f"""
<html>
<head>
<style>
body {{
margin: 0;
padding: 0;
text-align: center;
font-family: sans-serif;
background: transparent;
color: #fff;
}}
h2 {{
margin-top: 1rem;
}}
#scale-control {{
margin: 1rem auto;
width: 400px;
text-align: left;
}}
#scale-control label {{
font-size: 1rem;
margin-right: 0.5rem;
}}
#canvas-container {{
position: relative;
width: {canvas_size}px;
height: {canvas_size}px;
margin: 0 auto;
border: 1px dashed rgba(255,255,255,0.5);
overflow: hidden;
background-image: url('{background_img_b64}');
background-repeat: no-repeat;
background-position: center;
background-size: contain;
border-radius: 8px;
}}
#draggable-img {{
position: absolute;
cursor: grab;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
background-color: transparent;
}}
#coordinate {{
color: #fff;
margin-top: 1rem;
font-weight: bold;
}}
</style>
</head>
<body>
<h2>ๆ‹–ๆ‹ฝๅ‰ๆ™ฏๅ›พ๏ผˆๆ”ฏๆŒ็ผฉๆ”พ๏ผ‰</h2>
<div id="scale-control">
<label for="scale-slider">ๅ‰ๆ™ฏๅ›พ็ผฉๆ”พ:</label>
<input type="range" id="scale-slider" min="0.1" max="2" step="0.01" value="1">
</div>
<div id="canvas-container" data-bg-width="{bg_width}" data-bg-height="{bg_height}">
<img id="draggable-img"
src="{draggable_img_b64}"
alt="Draggable Image"
draggable="false"
data-original-width="{draggable_width}"
data-original-height="{draggable_height}"
/>
</div>
<p id="coordinate">ๅ‰ๆ™ฏๅ›พๅๆ ‡: (x=?, y=?)</p>
</body>
</html>
"""
        return html_code
def on_upload(self, background_img, draggable_img):
"""ไธŠไผ ๅ›พ็‰‡ๅŽ็š„ๅค„็†"""
if background_img is None or draggable_img is None:
return "<p style='color:red;'>่ฏทๅ…ˆไธŠไผ ่ƒŒๆ™ฏๅ›พ็‰‡ๅ’Œๅฏๆ‹–ๆ‹ฝๅ›พ็‰‡ใ€‚</p>"
if draggable_img.mode != "RGB":
draggable_img = draggable_img.convert("RGB")
draggable_img_mask = remove_bg(draggable_img)
alpha_channel = draggable_img_mask.convert("L")
draggable_img = draggable_img.convert("RGBA")
draggable_img.putalpha(alpha_channel)
resized_bg = self.resize_background_image(background_img, max_size=400)
bg_w, bg_h = resized_bg.size
resized_fg = self.resize_draggable_image(draggable_img, max_size=400)
draggable_width, draggable_height = resized_fg.size
background_img_b64 = self.pil_to_base64(resized_bg)
draggable_img_b64 = self.pil_to_base64(resized_fg)
return self.generate_html(
background_img_b64, bg_w, bg_h,
draggable_img_b64, draggable_width, draggable_height,
canvas_size=400
), draggable_img
    def create_gui(self):
        """Build and return the Gradio Blocks interface.

        Layout: an upload column plus a preview column, then a parameter
        column plus a result column. "Generate Canvas" runs on_upload and
        also re-initializes the JS drag handlers; "Run Model" is currently
        wired to a placeholder (see NOTE below).
        """
        config = InferenceConfig()
        config.lora_id = 'LL3RD/DreamFuse'
        pipeline = None
        # NOTE(review): the real DreamFuse pipeline is disabled in this
        # commit, and "Run Model" is wired to pil_to_base64 instead — but
        # ten inputs are passed to a one-image method, so clicking it will
        # raise a TypeError; confirm before re-enabling for users.
        # pipeline = DreamFuseInference(config)
        # pipeline.gradio_generate = spaces.GPU(duratioin=120)(pipeline.gradio_generate)
        """ๅˆ›ๅปบ Gradio ็•Œ้ข"""
        with gr.Blocks(css=self.css_style) as demo:
            # Holds the RGBA foreground produced by on_upload for the model run.
            modified_fg_state = gr.State()
            gr.Markdown("# DreamFuse: 3 Easy Steps to Create Your Fusion Image")
            gr.Markdown("1. Upload the foreground and background images you want to blend.")
            gr.Markdown("2. Click 'Generate Canvas' to preview the result. You can then drag and resize the foreground object to position it as you like.")
            gr.Markdown("3. Click 'Run Model' to create the final fused image.")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### FG&BG Image Upload")
                    background_img_in = gr.Image(label="Background Image", type="pil", height=240, width=240)
                    draggable_img_in = gr.Image(label="Foreground Image", type="pil", image_mode="RGBA", height=240, width=240)
                    generate_btn = gr.Button("Generate Canvas")
                    with gr.Row():
                        gr.Examples(
                            examples=[self.examples[0]],
                            inputs=[background_img_in, draggable_img_in],
                            elem_id="small-examples"
                        )
                with gr.Column(scale=1):
                    gr.Markdown("### Preview Region")
                    html_out = gr.HTML(label="drag and resize", elem_id="canvas_preview")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Parameters")
                    seed_slider = gr.Slider(minimum=-1, maximum=100000, step=1, label="Seed", value=12345)
                    cfg_slider = gr.Slider(minimum=1, maximum=10, step=0.1, label="CFG", value=3.5)
                    size_select = gr.Radio(
                        choices=["512", "768", "1024"],
                        value="512",
                        label="็”Ÿๆˆ่ดจ้‡(512-ๅทฎ 1024-ๅฅฝ)",
                    )
                    prompt_text = gr.Textbox(label="Prompt", placeholder="text prompt", value="")
                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength", value=1, visible=False)
                    enable_gui = gr.Checkbox(label="ๅฏ็”จGUI", value=True, visible=False)
                    enable_truecfg = gr.Checkbox(label="TrueCFG", value=False, visible=False)
                with gr.Column(scale=1):
                    gr.Markdown("### Model Result")
                    model_generate_btn = gr.Button("Run Model")
                    # Hidden textbox the front-end JS fills with the current
                    # drag/scale transformation JSON (see _get_js_script).
                    transformation_text = gr.Textbox(label="Transformation Info", elem_id="transformation_info", visible=False)
                    model_output = gr.Image(label="Model Output", type="pil")
            generate_btn.click(
                fn=self.on_upload,
                inputs=[background_img_in, draggable_img_in],
                outputs=[html_out, modified_fg_state],
            )
            model_generate_btn.click(
                # fn=pipeline.gradio_generate,
                fn=self.pil_to_base64,
                inputs=[background_img_in, modified_fg_state, transformation_text, seed_slider, \
                prompt_text, enable_gui, cfg_slider, size_select, text_strength, enable_truecfg],
                outputs=model_output
            )
            # Initialize the drag/scale handlers after the page loads.
            demo.load(None, None, None, js=self.js_script)
            generate_btn.click(fn=None, inputs=None, outputs=None, js="initializeDrag")
        return demo
if __name__ == "__main__":
    gui = DreamblendGUI()
    demo = gui.create_gui()
    # Enable request queueing before launch; presumably needed for the
    # @spaces.GPU-decorated callbacks on ZeroGPU — confirm for this Space.
    demo.queue()
    demo.launch()
    # Alternative launch configs kept for local debugging:
    # demo.launch(server_port=7789, ssr_mode=False)
    # demo.launch(server_name="[::]", share=True)