# Skywork UniPic2-Metaquery Gradio demo (Hugging Face Space app.py)
# Web-export header removed: yichenchenchen — "Update app.py", commit 001e7cd, 9.77 kB.
import spaces
import gradio as gr
from PIL import Image
import os
import tempfile
import sys
import time
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
import subprocess
#os.system("pip install ./flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl")
# Install the prebuilt flash-attn wheel bundled with the Space, using the
# current interpreter so the package lands in the right environment.
subprocess.check_call([sys.executable, "-m", "pip", "install", "./flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"])
def ensure_flash_attn():
    """Log whether flash-attn is importable in the current environment."""
    try:
        import flash_attn
    except ImportError:
        # NOTE(review): despite the message, no install is attempted here —
        # the wheel install above is what actually provides flash-attn.
        print("未安装 flash-attn,开始安装...")
    else:
        print("当前 flash-attn 已安装,版本:", flash_attn.__version__)
# Report flash-attn status before loading the model stack.
ensure_flash_attn()
from inferencer import UniPicV2Inferencer
# Fetch model weights from the Hugging Face Hub (cached between restarts).
model_path = snapshot_download(repo_id="Skywork/UniPic2-Metaquery-9B")
qwen_vl_path = snapshot_download(repo_id="Qwen/Qwen2.5-VL-7B-Instruct-AWQ")
# int4 quantization keeps the 9B model within the Space's GPU memory budget.
inferencer = UniPicV2Inferencer(
model_path=model_path,
qwen_vl_path=qwen_vl_path,
quant="int4"
)
#inferencer.pipeline = inferencer._init_pipeline()
# Scratch directory for images exchanged with the chatbot UI.
TEMP_DIR = tempfile.mkdtemp()
print(f"Temporary directory created at: {TEMP_DIR}")
def save_temp_image(pil_img, dest_dir=None):
    """Write *pil_img* to a uniquely named PNG and return its path.

    Args:
        pil_img: image to persist (anything with a ``save(path, format=...)``
            method, e.g. a PIL Image).
        dest_dir: target directory; defaults to the module-level TEMP_DIR.

    Returns:
        Path of the saved PNG file.
    """
    if dest_dir is None:
        dest_dir = TEMP_DIR
    # time_ns() rather than int(time.time()): second-resolution names collide
    # (and silently overwrite) when two images are saved in the same second.
    path = os.path.join(dest_dir, f"temp_{time.time_ns()}.png")
    pil_img.save(path, format="PNG")
    return path
def handle_image_upload(file, history):
    """Persist an uploaded image into the temp dir and append it to the chat.

    Returns (saved_path, updated_history); a None upload is a no-op.
    """
    if file is None:
        return None, history
    # UploadButton may hand over a tempfile wrapper (with .name) or a path str.
    src_path = getattr(file, "name", file)
    stored = save_temp_image(Image.open(src_path))
    return stored, history + [((stored,), None)]
def clear_all():
    """Delete saved temp images and reset the chat UI.

    Returns:
        (history, img_state, mode) matching the
        ``[chatbot, img_state, mode_selector]`` output binding.
    """
    for name in os.listdir(TEMP_DIR):
        path = os.path.join(TEMP_DIR, name)
        try:
            if os.path.isfile(path):
                os.remove(path)
        except Exception as e:
            # Best-effort cleanup: log and keep going.
            print(f"Failed to delete temp file: {path}, error: {e}")
    # Bug fix: "Language Output" is not one of the Radio choices
    # ("Generate Image" / "Edit Image" / "Understand Image"), so resetting
    # the selector to it left the widget in an invalid state. Reset to the
    # default mode instead.
    return [], None, "Generate Image"
def extract_assistant_reply(full_text):
    """Extract the assistant's final reply from a raw chat-template transcript.

    If an "assistant" role marker is present, everything after its last
    occurrence is returned (minus a leading colon); otherwise the text is
    returned with the end-of-message token stripped.
    """
    if "assistant" not in full_text:
        return full_text.replace("<|im_end|>", "").strip()
    _, _, tail = full_text.strip().rpartition("assistant")
    return tail.lstrip(":").strip()
def on_submit(history, user_msg, img_path, mode, infer_steps, cfg_scale, model_version='sd3.5-512', seed=3):
    """Route one chat turn to understand / generate / edit.

    Args:
        history: current chatbot history as (user, bot) pairs.
        user_msg: raw text from the prompt box.
        img_path: path of the last uploaded/generated image, or None.
        mode: "Understand Image", "Generate Image" or "Edit Image".
        infer_steps: diffusion sampling steps.
        cfg_scale: classifier-free guidance scale.
        model_version: kept for interface compatibility; unused here.
        seed: random seed forwarded to the inferencer.

    Returns:
        (updated_history, cleared_textbox_text, new_img_path)
    """
    user_msg = user_msg.strip()
    updated_history = history + [(user_msg, None)]
    edit_tip = (
        "✅ You can continue editing this image by switching the mode to "
        "'Edit Image' and entering your instruction. "
    )

    if mode == "Understand Image":
        if img_path is None:
            # Fixed: append a tuple (not a list) for consistency with every
            # other history entry.
            updated_history.append((None, "⚠️ Please upload or generate an image first."))
            return updated_history, "", img_path
        # Text-only fallback kept from the original flow, although img_path
        # is always set on this branch after the guard above.
        raw_output = (
            inferencer.understand_image(Image.open(img_path), user_msg)
            if img_path
            else inferencer.query_text(user_msg)
        )
        return updated_history + [(None, raw_output)], "", img_path

    if mode == "Generate Image":
        if not user_msg:
            return (
                updated_history + [(None, "⚠️ Please enter your message for generating.")],
                "",
                img_path,
            )
        imgs = inferencer.generate_image(
            user_msg,
            num_inference_steps=infer_steps,
            seed=seed,
            guidance_scale=cfg_scale,
        )
        path = save_temp_image(imgs[0])
        # edit_tip was previously an unused local duplicated as an inline
        # literal here; use the variable.
        return updated_history + [(None, (path,)), (None, edit_tip)], "", path

    if mode == "Edit Image":
        if img_path is None:
            return (
                updated_history + [(None, "⚠️ Please upload or generate an image first.")],
                "",
                img_path,
            )
        if not user_msg:
            return (
                updated_history + [(None, "⚠️ Please enter your message for editing.")],
                "",
                img_path,
            )
        edited_img = inferencer.edit_image(
            Image.open(img_path),
            user_msg,
            num_inference_steps=infer_steps,
            seed=seed,
            guidance_scale=cfg_scale,
        )[0]
        path = save_temp_image(edited_img)
        return (
            updated_history
            + [
                (None, (path,)),
                (None, "✅ You can continue editing this image by entering your instruction. "),
            ],
            "",
            path,
        )

    # Unknown mode: previously fell off the end and returned None, which
    # breaks the 3-output Gradio binding. Echo state unchanged instead.
    return updated_history, "", img_path
# CSS for the chat layout (written for compatibility with older Gradio).
# NOTE: the string content is passed verbatim to gr.Blocks(css=...), so the
# Chinese comments inside it are runtime data and are left untouched.
CSS = """
/* 整体布局 */
.gradio-container {
display: flex !important;
flex-direction: column;
height: 100vh;
margin: 0;
padding: 0;
}
/* 让 tab 自适应填满剩余高度 */
.gr-tabs {
flex: 1 1 auto;
display: flex;
flex-direction: column;
}
/* 聊天 tab 主体 */
#tab_item_4 {
display: flex;
flex-direction: column;
flex: 1 1 auto;
overflow: hidden;
padding: 8px;
}
/* Chatbot 区域 */
#chatbot1 {
flex: 0 0 55vh !important;
overflow-y: auto !important;
border: 1px solid #ddd;
border-radius: 8px;
padding: 12px;
margin-bottom: 8px;
}
#chatbot1 img {
max-width: 80vw !important;
height: auto !important;
border-radius: 4px;
}
/* 控制面板外框 */
.control-panel {
border: 1px solid #ddd;
border-radius: 8px;
padding: 8px;
margin-bottom: 8px;
}
/* 控制行:三列不换行 */
.control-row {
display: flex;
align-items: stretch;
gap: 8px;
flex-wrap: nowrap; /* 不换行 */
}
/* 单个控件卡片样式 */
.control-box {
border: 1px solid #ccc;
border-radius: 8px;
padding: 8px;
background: #f9f9f9;
flex: 1;
min-width: 150px;
box-sizing: border-box;
}
/* 强制旧版 Radio 横向排列 */
.control-box .wrap {
display: flex !important;
flex-direction: row !important;
gap: 8px !important;
}
/* 输入区 */
.input-row {
flex: 0 0 auto;
display: flex;
align-items: center;
padding: 8px;
border-top: 1px solid #eee;
background: #fafafa;
}
.textbox-col { flex: 1; }
.upload-col, .clear-col { flex: 0 0 120px; }
.gr-text-input {
width: 100% !important;
border-radius: 18px !important;
padding: 8px 16px !important;
border: 1px solid #ddd !important;
font-size: 16px !important;
}
"""
with gr.Blocks(css=CSS) as demo:
    # Path of the most recent uploaded/generated image, shared across turns.
    img_state = gr.State(value=None)

    with gr.Tabs():
        with gr.Tab("Skywork UniPic2-Metaquery int4", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                avatar_images=("user.png", "ai.png"),
            )

            # Mode selector and sampling controls.
            with gr.Column(elem_classes="control-panel"):
                with gr.Row(elem_classes="control-row"):
                    with gr.Column(elem_classes="control-box", scale=2, min_width=200):
                        mode_selector = gr.Radio(
                            choices=["Generate Image", "Edit Image", "Understand Image"],
                            value="Generate Image",
                            label="Mode",
                            interactive=True,
                        )
                    with gr.Column(elem_classes="control-box", scale=1, min_width=150):
                        infer_steps = gr.Slider(
                            label="Sample Steps",
                            minimum=1,
                            maximum=200,
                            value=50,
                            step=1,
                            interactive=True,
                        )
                    with gr.Column(elem_classes="control-box", scale=1, min_width=150):
                        cfg_scale = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=16.0,
                            value=3.5,
                            step=0.5,
                            interactive=True,
                        )

            # Prompt box, image upload and clear button.
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        placeholder="Type your message here...",
                        label="prompt",
                        lines=1,
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        file_types=["image"],
                        file_count="single",
                        type="filepath",
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")

            # Event wiring.
            user_input.submit(
                on_submit,
                inputs=[chatbot, user_input, img_state, mode_selector, infer_steps, cfg_scale],
                outputs=[chatbot, user_input, img_state],
            )
            image_input.upload(
                handle_image_upload,
                inputs=[image_input, chatbot],
                outputs=[img_state, chatbot],
            )
            clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])

demo.launch()
# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", debug=True, server_port=8004)