# app.py — Skywork UniPic2-Metaquery Gradio demo (Hugging Face Space, commit 189b26e)
import spaces
import gradio as gr
from PIL import Image
import os
import tempfile
import sys
import time
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
import subprocess
#os.system("pip install ./flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl")
subprocess.check_call([sys.executable, "-m", "pip", "install", "./flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"])
def ensure_flash_attn():
try:
import flash_attn
print("当前 flash-attn 已安装,版本:", flash_attn.__version__)
except ImportError:
print("未安装 flash-attn")
ensure_flash_attn()
from inferencer import UniPicV2Inferencer
os.environ["HF_HUB_REQUEST_TIMEOUT"] = "60.0"
model_path = snapshot_download(repo_id="Skywork/UniPic2-Metaquery-9B")
qwen_vl_path = snapshot_download(repo_id="Qwen/Qwen2.5-VL-7B-Instruct")
inferencer = UniPicV2Inferencer(
model_path=model_path,
qwen_vl_path=qwen_vl_path,
quant="fp16"
)
#inferencer.pipeline = inferencer._init_pipeline()
TEMP_DIR = tempfile.mkdtemp()
print(f"Temporary directory created at: {TEMP_DIR}")
def save_temp_image(pil_img):
path = os.path.join(TEMP_DIR, f"temp_{int(time.time())}.png")
pil_img.save(path, format="PNG")
return path
def handle_image_upload(file, history):
if file is None:
return None, history
file_path = file.name if hasattr(file, "name") else file
pil_img = Image.open(file_path)
saved_path = save_temp_image(pil_img)
return saved_path, history + [((saved_path,), None)]
def clear_all():
for file in os.listdir(TEMP_DIR):
path = os.path.join(TEMP_DIR, file)
try:
if os.path.isfile(path):
os.remove(path)
except Exception as e:
print(f"Failed to delete temp file: {path}, error: {e}")
return [], None, "Language Output"
def extract_assistant_reply(full_text):
if "assistant" in full_text:
parts = full_text.strip().split("assistant")
return parts[-1].lstrip(":").strip()
return full_text.replace("<|im_end|>", "").strip()
def on_submit(history, user_msg, img_path, mode, infer_steps, cfg_scale, model_version='sd3.5-512', seed=42):
user_msg = user_msg.strip()
updated_history = history + [(user_msg, None)]
edit_tip = "✅ You can continue editing this image by switching the mode to 'Edit Image' and entering your instruction. "
# seed = int(seed)
# set_seed(seed) # 设置随机种子以确保可重复性
try:
if mode == "Understand Image":
if img_path is None:
updated_history.append([None, "⚠️ Please upload or generate an image first."])
return updated_history, "", img_path
raw_output = (
inferencer.understand_image(Image.open(img_path), user_msg)
if img_path
else inferencer.query_text(user_msg)
)
clean_output = raw_output
return (
updated_history + [(None, clean_output)],
"",
img_path,
) # 保持 img_path 不变
if mode == "Generate Image":
if not user_msg:
return (
updated_history
+ [(None, "⚠️ Please enter your message for generating.")],
"",
img_path,
)
imgs = inferencer.generate_image(user_msg, num_inference_steps=infer_steps, seed=seed, guidance_scale=cfg_scale)
path = save_temp_image(imgs[0])
return (
updated_history
+ [
(None, (path,)),
(
None,
"✅ You can continue editing this image by switching the mode to 'Edit Image' and entering your instruction. ",
),
],
"",
path,
) # 更新 img_state
elif mode == "Edit Image":
if img_path is None:
return (
updated_history
+ [(None, "⚠️ Please upload or generate an image first.")],
"",
img_path,
)
if not user_msg:
return (
updated_history
+ [(None, "⚠️ Please enter your message for editing.")],
"",
img_path,
)
edited_img = inferencer.edit_image(Image.open(img_path), user_msg, num_inference_steps=infer_steps, seed=seed, guidance_scale=cfg_scale)[0]
path = save_temp_image(edited_img)
return (
updated_history
+ [
(None, (path,)),
(
None,
"✅ You can continue editing this image by entering your instruction. ",
),
],
"",
path,
) # 更新 img_state
except Exception as e:
return updated_history + [(None, f"⚠️ Failed to process: {e}")], "", img_path
# 定义CSS样式(兼容旧版 Gradio)
CSS = """
/* 整体布局 */
.gradio-container {
display: flex !important;
flex-direction: column;
height: 100vh;
margin: 0;
padding: 0;
}
/* 让 tab 自适应填满剩余高度 */
.gr-tabs {
flex: 1 1 auto;
display: flex;
flex-direction: column;
}
/* 聊天 tab 主体 */
#tab_item_4 {
display: flex;
flex-direction: column;
flex: 1 1 auto;
overflow: hidden;
padding: 8px;
}
/* Chatbot 区域 */
#chatbot1 {
flex: 0 0 55vh !important;
overflow-y: auto !important;
border: 1px solid #ddd;
border-radius: 8px;
padding: 12px;
margin-bottom: 8px;
}
#chatbot1 img {
max-width: 80vw !important;
height: auto !important;
border-radius: 4px;
}
/* 控制面板外框 */
.control-panel {
border: 1px solid #ddd;
border-radius: 8px;
padding: 8px;
margin-bottom: 8px;
}
/* 控制行:三列不换行 */
.control-row {
display: flex;
align-items: stretch;
gap: 8px;
flex-wrap: nowrap; /* 不换行 */
}
/* 单个控件卡片样式 */
.control-box {
border: 1px solid #ccc;
border-radius: 8px;
padding: 8px;
background: #f9f9f9;
flex: 1;
min-width: 150px;
box-sizing: border-box;
}
/* 强制旧版 Radio 横向排列 */
.control-box .wrap {
display: flex !important;
flex-direction: row !important;
gap: 8px !important;
}
/* 输入区 */
.input-row {
flex: 0 0 auto;
display: flex;
align-items: center;
padding: 8px;
border-top: 1px solid #eee;
background: #fafafa;
}
.textbox-col { flex: 1; }
.upload-col, .clear-col { flex: 0 0 120px; }
.gr-text-input {
width: 100% !important;
border-radius: 18px !important;
padding: 8px 16px !important;
border: 1px solid #ddd !important;
font-size: 16px !important;
}
"""
with gr.Blocks(css=CSS) as demo:
img_state = gr.State(value=None)
with gr.Tabs():
with gr.Tab("Skywork UniPic2-Metaquery", elem_id="tab_item_4"):
chatbot = gr.Chatbot(
elem_id="chatbot1",
show_label=False,
avatar_images=("user.png", "ai.png"),
)
# 控制区域
with gr.Column(elem_classes="control-panel"):
with gr.Row(elem_classes="control-row"):
with gr.Column(elem_classes="control-box", scale=2, min_width=200):
mode_selector = gr.Radio(
choices=["Generate Image", "Edit Image", "Understand Image"],
value="Generate Image",
label="Mode",
interactive=True
)
with gr.Column(elem_classes="control-box", scale=1, min_width=150):
infer_steps = gr.Slider(
label="Sample Steps",
minimum=1,
maximum=200,
value=50,
step=1,
interactive=True,
)
with gr.Column(elem_classes="control-box", scale=1, min_width=150):
cfg_scale = gr.Slider(
label="CFG Scale",
minimum=1.0,
maximum=16.0,
value=3.5,
step=0.5,
interactive=True,
)
# 输入区域
with gr.Row(elem_classes="input-row"):
with gr.Column(elem_classes="textbox-col"):
user_input = gr.Textbox(
placeholder="Type your message here...",
label="prompt",
lines=1,
)
with gr.Column(elem_classes="upload-col"):
image_input = gr.UploadButton(
"📷 Upload Image",
file_types=["image"],
file_count="single",
type="filepath",
)
with gr.Column(elem_classes="clear-col"):
clear_btn = gr.Button("🧹 Clear History")
# 交互绑定
user_input.submit(
on_submit,
inputs=[chatbot, user_input, img_state, mode_selector, infer_steps, cfg_scale],
outputs=[chatbot, user_input, img_state],
)
image_input.upload(
handle_image_upload, inputs=[image_input, chatbot], outputs=[img_state, chatbot]
)
clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])
if __name__ == "__main__":
demo.launch()
# demo.launch(server_name="0.0.0.0", debug=True, server_port=8004)