"""Gradio chat demo for Skywork UniPic2-Metaquery.

One chat interface with three modes: generate an image from text, edit the
current image with an instruction, or answer questions about the current image.

Module-level side effects (required on HF Spaces): installs a bundled
flash-attn wheel, downloads model snapshots from the Hugging Face Hub, and
instantiates the inferencer.
"""
import os
import subprocess
import sys
import tempfile
import time

import spaces  # noqa: F401  (import required for HF Spaces GPU allocation)
import gradio as gr
from PIL import Image
from accelerate.utils import set_seed  # noqa: F401  (kept for optional explicit seeding)
from huggingface_hub import snapshot_download

# Install the bundled flash-attn wheel BEFORE importing the inferencer,
# which needs flash-attn at import time.
subprocess.check_call(
    [
        sys.executable, "-m", "pip", "install",
        "./flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
    ]
)


def ensure_flash_attn():
    """Log whether flash-attn is importable and, if so, its version."""
    try:
        import flash_attn
        print("当前 flash-attn 已安装,版本:", flash_attn.__version__)
    except ImportError:
        print("未安装 flash-attn")


ensure_flash_attn()

from inferencer import UniPicV2Inferencer  # noqa: E402  (must follow wheel install)

os.environ["HF_HUB_REQUEST_TIMEOUT"] = "60.0"

model_path = snapshot_download(repo_id="Skywork/UniPic2-Metaquery-9B")
qwen_vl_path = snapshot_download(repo_id="Qwen/Qwen2.5-VL-7B-Instruct")

inferencer = UniPicV2Inferencer(
    model_path=model_path,
    qwen_vl_path=qwen_vl_path,
    quant="fp16",
)

# Scratch directory for generated / uploaded images shown in the chat.
TEMP_DIR = tempfile.mkdtemp()
print(f"Temporary directory created at: {TEMP_DIR}")


def save_temp_image(pil_img):
    """Save *pil_img* as PNG in TEMP_DIR and return the saved path.

    Uses a nanosecond timestamp for the filename: the previous
    ``int(time.time())`` scheme collided (silently overwrote) when two
    images were saved within the same second.
    """
    path = os.path.join(TEMP_DIR, f"temp_{time.time_ns()}.png")
    pil_img.save(path, format="PNG")
    return path


def handle_image_upload(file, history):
    """Persist an uploaded image and append it to the chat history.

    Returns ``(saved_path, updated_history)``; ``(None, history)`` when
    nothing was uploaded.
    """
    if file is None:
        return None, history
    # UploadButton may yield a tempfile wrapper or a plain path string.
    file_path = file.name if hasattr(file, "name") else file
    # Close the file handle promptly instead of leaking it.
    with Image.open(file_path) as pil_img:
        saved_path = save_temp_image(pil_img)
    return saved_path, history + [((saved_path,), None)]


def clear_all():
    """Delete temp images and reset chat history, image state, and mode.

    Cleanup is best-effort: a failed delete is logged and never blocks
    the UI reset.
    """
    for name in os.listdir(TEMP_DIR):
        path = os.path.join(TEMP_DIR, name)
        try:
            if os.path.isfile(path):
                os.remove(path)
        except OSError as e:
            print(f"Failed to delete temp file: {path}, error: {e}")
    # Bug fix: the old reset value "Language Output" is not one of the
    # Radio choices, leaving the mode selector in an invalid state.
    return [], None, "Generate Image"


def extract_assistant_reply(full_text):
    """Strip chat-template scaffolding and return only the assistant reply."""
    if "assistant" in full_text:
        parts = full_text.strip().split("assistant")
        return parts[-1].lstrip(":").strip()
    return full_text.replace("<|im_end|>", "").strip()


def on_submit(history, user_msg, img_path, mode, infer_steps,
              cfg_scale, model_version='sd3.5-512', seed=42):
    """Route one chat submission to understand / generate / edit.

    Args:
        history: current chatbot message list.
        user_msg: raw text from the input box.
        img_path: path of the current working image (or None).
        mode: one of "Generate Image", "Edit Image", "Understand Image".
        infer_steps, cfg_scale, seed: sampling parameters forwarded to the
            inferencer.
        model_version: currently unused; kept for interface stability.

    Returns:
        (updated_history, cleared_textbox_value, new_img_path).
    """
    user_msg = user_msg.strip()
    updated_history = history + [(user_msg, None)]
    edit_tip = (
        "✅ You can continue editing this image by switching the mode to "
        "'Edit Image' and entering your instruction. "
    )
    try:
        if mode == "Understand Image":
            if img_path is None:
                updated_history.append((None, "⚠️ Please upload or generate an image first."))
                return updated_history, "", img_path
            # img_path is guaranteed non-None here, so the previous
            # `query_text` fallback was unreachable dead code.
            clean_output = inferencer.understand_image(Image.open(img_path), user_msg)
            return updated_history + [(None, clean_output)], "", img_path

        if mode == "Generate Image":
            if not user_msg:
                return (
                    updated_history + [(None, "⚠️ Please enter your message for generating.")],
                    "",
                    img_path,
                )
            imgs = inferencer.generate_image(
                user_msg,
                num_inference_steps=infer_steps,
                seed=seed,
                guidance_scale=cfg_scale,
            )
            path = save_temp_image(imgs[0])
            # The freshly generated image becomes the editing target.
            return updated_history + [(None, (path,)), (None, edit_tip)], "", path

        if mode == "Edit Image":
            if img_path is None:
                return (
                    updated_history + [(None, "⚠️ Please upload or generate an image first.")],
                    "",
                    img_path,
                )
            if not user_msg:
                return (
                    updated_history + [(None, "⚠️ Please enter your message for editing.")],
                    "",
                    img_path,
                )
            edited_img = inferencer.edit_image(
                Image.open(img_path),
                user_msg,
                num_inference_steps=infer_steps,
                seed=seed,
                guidance_scale=cfg_scale,
            )[0]
            path = save_temp_image(edited_img)
            return (
                updated_history
                + [
                    (None, (path,)),
                    (None, "✅ You can continue editing this image by entering your instruction. "),
                ],
                "",
                path,  # edited image replaces the working image
            )

        # Unknown mode — cannot happen with the Radio choices, but the old
        # code fell through and implicitly returned None, which broke the
        # Gradio output mapping. Echo current state instead.
        return updated_history, "", img_path
    except Exception as e:
        # Surface any inference failure inside the chat rather than crashing.
        return updated_history + [(None, f"⚠️ Failed to process: {e}")], "", img_path


# CSS for the (legacy-Gradio-compatible) layout. Kept verbatim: this string
# is shipped to the browser at runtime.
CSS = """
/* 整体布局 */
.gradio-container {
    display: flex !important;
    flex-direction: column;
    height: 100vh;
    margin: 0;
    padding: 0;
}
/* 让 tab 自适应填满剩余高度 */
.gr-tabs {
    flex: 1 1 auto;
    display: flex;
    flex-direction: column;
}
/* 聊天 tab 主体 */
#tab_item_4 {
    display: flex;
    flex-direction: column;
    flex: 1 1 auto;
    overflow: hidden;
    padding: 8px;
}
/* Chatbot 区域 */
#chatbot1 {
    flex: 0 0 55vh !important;
    overflow-y: auto !important;
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 8px;
}
#chatbot1 img {
    max-width: 80vw !important;
    height: auto !important;
    border-radius: 4px;
}
/* 控制面板外框 */
.control-panel {
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 8px;
    margin-bottom: 8px;
}
/* 控制行:三列不换行 */
.control-row {
    display: flex;
    align-items: stretch;
    gap: 8px;
    flex-wrap: nowrap; /* 不换行 */
}
/* 单个控件卡片样式 */
.control-box {
    border: 1px solid #ccc;
    border-radius: 8px;
    padding: 8px;
    background: #f9f9f9;
    flex: 1;
    min-width: 150px;
    box-sizing: border-box;
}
/* 强制旧版 Radio 横向排列 */
.control-box .wrap {
    display: flex !important;
    flex-direction: row !important;
    gap: 8px !important;
}
/* 输入区 */
.input-row {
    flex: 0 0 auto;
    display: flex;
    align-items: center;
    padding: 8px;
    border-top: 1px solid #eee;
    background: #fafafa;
}
.textbox-col {
    flex: 1;
}
.upload-col, .clear-col {
    flex: 0 0 120px;
}
.gr-text-input {
    width: 100% !important;
    border-radius: 18px !important;
    padding: 8px 16px !important;
    border: 1px solid #ddd !important;
    font-size: 16px !important;
}
"""

with gr.Blocks(css=CSS) as demo:
    # Path of the image currently being viewed/edited; shared across events.
    img_state = gr.State(value=None)

    with gr.Tabs():
        with gr.Tab("Skywork UniPic2-Metaquery", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                avatar_images=("user.png", "ai.png"),
            )

            # Control panel: mode + sampling parameters.
            with gr.Column(elem_classes="control-panel"):
                with gr.Row(elem_classes="control-row"):
                    with gr.Column(elem_classes="control-box", scale=2, min_width=200):
                        mode_selector = gr.Radio(
                            choices=["Generate Image", "Edit Image", "Understand Image"],
                            value="Generate Image",
                            label="Mode",
                            interactive=True,
                        )
                    with gr.Column(elem_classes="control-box", scale=1, min_width=150):
                        infer_steps = gr.Slider(
                            label="Sample Steps",
                            minimum=1,
                            maximum=200,
                            value=50,
                            step=1,
                            interactive=True,
                        )
                    with gr.Column(elem_classes="control-box", scale=1, min_width=150):
                        cfg_scale = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=16.0,
                            value=3.5,
                            step=0.5,
                            interactive=True,
                        )

            # Input row: prompt box, image upload, clear button.
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        placeholder="Type your message here...",
                        label="prompt",
                        lines=1,
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        file_types=["image"],
                        file_count="single",
                        type="filepath",
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")

            # Event wiring.
            user_input.submit(
                on_submit,
                inputs=[chatbot, user_input, img_state, mode_selector, infer_steps, cfg_scale],
                outputs=[chatbot, user_input, img_state],
            )
            image_input.upload(
                handle_image_upload,
                inputs=[image_input, chatbot],
                outputs=[img_state, chatbot],
            )
            clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])

if __name__ == "__main__":
    demo.launch()