import spaces
import gradio as gr
from PIL import Image
import os
import tempfile
import sys
import time
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
import subprocess

# os.system("pip install ./flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl")
subprocess.check_call([sys.executable, "-m", "pip", "install", "./flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"])
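# Note: the wheel filename encodes the build it was compiled against (flash-attn 2.7.4.post1,
# CUDA 12, torch 2.6, CPython 3.10, cxx11abi=FALSE); installing a prebuilt wheel avoids
# compiling flash-attn from source at Space startup. The check below only reports whether
# the import succeeds.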
def ensure_flash_attn():
    try:
        import flash_attn
        print("flash-attn is installed, version:", flash_attn.__version__)
    except ImportError:
        print("flash-attn is not installed")

ensure_flash_attn()
from inferencer import UniPicV2Inferencer

os.environ["HF_HUB_REQUEST_TIMEOUT"] = "60.0"
model_path = snapshot_download(repo_id="Skywork/UniPic2-Metaquery-9B")
qwen_vl_path = snapshot_download(repo_id="Qwen/Qwen2.5-VL-7B-Instruct")

inferencer = UniPicV2Inferencer(
    model_path=model_path,
    qwen_vl_path=qwen_vl_path,
    quant="fp16"
)
# inferencer.pipeline = inferencer._init_pipeline()

TEMP_DIR = tempfile.mkdtemp()
print(f"Temporary directory created at: {TEMP_DIR}")
def save_temp_image(pil_img):
    path = os.path.join(TEMP_DIR, f"temp_{int(time.time())}.png")
    pil_img.save(path, format="PNG")
    return path
def handle_image_upload(file, history):
    if file is None:
        return None, history
    file_path = file.name if hasattr(file, "name") else file
    pil_img = Image.open(file_path)
    saved_path = save_temp_image(pil_img)
    return saved_path, history + [((saved_path,), None)]
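# In the tuple-style Chatbot history used here, a (filepath,) tuple in a message slot is
# rendered as an inline image, so the uploaded picture shows up as a user message.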
def clear_all():
    for file in os.listdir(TEMP_DIR):
        path = os.path.join(TEMP_DIR, file)
        try:
            if os.path.isfile(path):
                os.remove(path)
        except Exception as e:
            print(f"Failed to delete temp file: {path}, error: {e}")
    # Reset chat history, image state, and the mode selector to its default choice.
    return [], None, "Generate Image"
def extract_assistant_reply(full_text):
    if "assistant" in full_text:
        parts = full_text.strip().split("assistant")
        return parts[-1].lstrip(":").strip()
    return full_text.replace("<|im_end|>", "").strip()
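# Note: `spaces` (imported above) is the Hugging Face ZeroGPU helper; on ZeroGPU hardware,
# GPU-bound handlers are typically decorated with @spaces.GPU so a GPU is attached for the
# duration of the call. Whether on_submit needs that here depends on how UniPicV2Inferencer
# allocates its device.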
def on_submit(history, user_msg, img_path, mode, infer_steps, cfg_scale, model_version='sd3.5-512', seed=42):
    user_msg = user_msg.strip()
    updated_history = history + [(user_msg, None)]
    edit_tip = "✅ You can continue editing this image by switching the mode to 'Edit Image' and entering your instruction. "
    # seed = int(seed)
    # set_seed(seed)  # set the random seed for reproducibility
    try:
        if mode == "Understand Image":
            if img_path is None:
                updated_history.append([None, "⚠️ Please upload or generate an image first."])
                return updated_history, "", img_path
            raw_output = (
                inferencer.understand_image(Image.open(img_path), user_msg)
                if img_path
                else inferencer.query_text(user_msg)
            )
            clean_output = raw_output
            return (
                updated_history + [(None, clean_output)],
                "",
                img_path,
            )  # keep img_path unchanged
if mode == "Generate Image": | |
if not user_msg: | |
return ( | |
updated_history | |
+ [(None, "⚠️ Please enter your message for generating.")], | |
"", | |
img_path, | |
) | |
imgs = inferencer.generate_image(user_msg, num_inference_steps=infer_steps, seed=seed, guidance_scale=cfg_scale) | |
path = save_temp_image(imgs[0]) | |
return ( | |
updated_history | |
+ [ | |
(None, (path,)), | |
( | |
None, | |
"✅ You can continue editing this image by switching the mode to 'Edit Image' and entering your instruction. ", | |
), | |
], | |
"", | |
path, | |
) # 更新 img_state | |
        elif mode == "Edit Image":
            if img_path is None:
                return (
                    updated_history
                    + [(None, "⚠️ Please upload or generate an image first.")],
                    "",
                    img_path,
                )
            if not user_msg:
                return (
                    updated_history
                    + [(None, "⚠️ Please enter your message for editing.")],
                    "",
                    img_path,
                )
            edited_img = inferencer.edit_image(Image.open(img_path), user_msg, num_inference_steps=infer_steps, seed=seed, guidance_scale=cfg_scale)[0]
            path = save_temp_image(edited_img)
            return (
                updated_history
                + [
                    (None, (path,)),
                    (None, "✅ You can continue editing this image by entering your instruction. "),
                ],
                "",
                path,
            )  # update img_state
    except Exception as e:
        return updated_history + [(None, f"⚠️ Failed to process: {e}")], "", img_path
# Define the CSS styles (written for older Gradio versions)
CSS = """
/* Overall layout */
.gradio-container {
    display: flex !important;
    flex-direction: column;
    height: 100vh;
    margin: 0;
    padding: 0;
}
/* Let the tabs stretch to fill the remaining height */
.gr-tabs {
    flex: 1 1 auto;
    display: flex;
    flex-direction: column;
}
/* Chat tab body */
#tab_item_4 {
    display: flex;
    flex-direction: column;
    flex: 1 1 auto;
    overflow: hidden;
    padding: 8px;
}
/* Chatbot area */
#chatbot1 {
    flex: 0 0 55vh !important;
    overflow-y: auto !important;
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 8px;
}
#chatbot1 img {
    max-width: 80vw !important;
    height: auto !important;
    border-radius: 4px;
}
/* Control panel frame */
.control-panel {
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 8px;
    margin-bottom: 8px;
}
/* Control row: three columns, no wrapping */
.control-row {
    display: flex;
    align-items: stretch;
    gap: 8px;
    flex-wrap: nowrap; /* do not wrap */
}
/* Card style for a single control */
.control-box {
    border: 1px solid #ccc;
    border-radius: 8px;
    padding: 8px;
    background: #f9f9f9;
    flex: 1;
    min-width: 150px;
    box-sizing: border-box;
}
/* Force the legacy Radio options into a horizontal row */
.control-box .wrap {
    display: flex !important;
    flex-direction: row !important;
    gap: 8px !important;
}
/* Input area */
.input-row {
    flex: 0 0 auto;
    display: flex;
    align-items: center;
    padding: 8px;
    border-top: 1px solid #eee;
    background: #fafafa;
}
.textbox-col { flex: 1; }
.upload-col, .clear-col { flex: 0 0 120px; }
.gr-text-input {
    width: 100% !important;
    border-radius: 18px !important;
    padding: 8px 16px !important;
    border: 1px solid #ddd !important;
    font-size: 16px !important;
}
"""
with gr.Blocks(css=CSS) as demo:
    img_state = gr.State(value=None)
    with gr.Tabs():
        with gr.Tab("Skywork UniPic2-Metaquery", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                avatar_images=("user.png", "ai.png"),
            )
            # Control area
            with gr.Column(elem_classes="control-panel"):
                with gr.Row(elem_classes="control-row"):
                    with gr.Column(elem_classes="control-box", scale=2, min_width=200):
                        mode_selector = gr.Radio(
                            choices=["Generate Image", "Edit Image", "Understand Image"],
                            value="Generate Image",
                            label="Mode",
                            interactive=True,
                        )
                    with gr.Column(elem_classes="control-box", scale=1, min_width=150):
                        infer_steps = gr.Slider(
                            label="Sample Steps",
                            minimum=1,
                            maximum=200,
                            value=50,
                            step=1,
                            interactive=True,
                        )
                    with gr.Column(elem_classes="control-box", scale=1, min_width=150):
                        cfg_scale = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=16.0,
                            value=3.5,
                            step=0.5,
                            interactive=True,
                        )
            # Input area
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        placeholder="Type your message here...",
                        label="prompt",
                        lines=1,
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        file_types=["image"],
                        file_count="single",
                        type="filepath",
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")
            # Event bindings
            user_input.submit(
                on_submit,
                inputs=[chatbot, user_input, img_state, mode_selector, infer_steps, cfg_scale],
                outputs=[chatbot, user_input, img_state],
            )
            image_input.upload(
                handle_image_upload, inputs=[image_input, chatbot], outputs=[img_state, chatbot]
            )
            clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])
if __name__ == "__main__":
    demo.launch()
    # demo.launch(server_name="0.0.0.0", debug=True, server_port=8004)