|
import gradio as gr |
|
from PIL import Image |
|
import os |
|
import tempfile |
|
import sys |
|
import time |
|
from inferencer import Inferencer |
|
from accelerate.utils import set_seed |
|
from huggingface_hub import snapshot_download |
|
|
|
# Locate the published checkpoint through huggingface_hub and point at its
# weight file inside the returned directory.
model_dir = snapshot_download(repo_id="Skywork/Skywork-UniPic-1.5B")
model_path = os.path.join(model_dir, "pytorch_model.bin")
ckpt_name = "UniPic"

# One module-level model wrapper, shared by the request handlers in this file.
inferencer = Inferencer(
    config_file="qwen2_5_1_5b_kl16_mar_h.py",
    model_path=model_path,
    image_size=1024,
)

# Scratch directory that receives every uploaded and generated image.
TEMP_DIR = tempfile.mkdtemp()
print(f"Temporary directory created at: {TEMP_DIR}")
|
|
|
def save_temp_image(pil_img):
    """Save ``pil_img`` as a PNG inside ``TEMP_DIR`` and return its path.

    The file name keeps the ``temp_<unix-seconds>`` prefix, but the remainder
    is chosen by ``tempfile.mkstemp`` so every call gets its own file.  Naming
    files by the whole-second timestamp alone made images saved within the
    same second (for example several outputs of a single request) silently
    overwrite one another.

    Args:
        pil_img: Object exposing ``save(path, format=...)``, e.g. a PIL image.

    Returns:
        Path of the newly written PNG file.

    Raises:
        Exception: Whatever ``pil_img.save`` raises; the reserved file is
            removed before the error propagates.
    """
    fd, path = tempfile.mkstemp(
        prefix=f"temp_{int(time.time())}_", suffix=".png", dir=TEMP_DIR
    )
    # mkstemp returns an open descriptor; only the reserved name is needed.
    os.close(fd)
    try:
        pil_img.save(path, format="PNG")
    except Exception:
        # Do not leave an empty placeholder behind when encoding fails.
        if os.path.exists(path):
            os.remove(path)
        raise
    return path
|
|
|
def handle_image_upload(file, history):
    """Store an uploaded image in the temp directory and show it in the chat.

    Args:
        file: The upload, either a path string or an object whose ``name``
            attribute holds the path; ``None`` when nothing was uploaded.
        history: Current chat history (sequence of message pairs).  ``None``
            is treated as an empty history when a new entry is appended.

    Returns:
        ``(saved_path, new_history)``.  When ``file`` is ``None`` the result
        is ``(None, history)`` with the history passed through untouched.
        Otherwise ``new_history`` is a new list ending with a user-side image
        message ``((saved_path,), None)``.
    """
    if file is None:
        return None, history

    file_path = getattr(file, "name", file)
    # Image.open keeps the source file open; the context manager releases the
    # handle as soon as the PNG copy has been written.
    with Image.open(file_path) as pil_img:
        saved_path = save_temp_image(pil_img)

    return saved_path, list(history or []) + [((saved_path,), None)]
|
|
|
def clear_all():
    """Remove every regular file from ``TEMP_DIR`` and return reset UI values.

    Deletion is best effort: a failure on one file is printed and the loop
    moves on to the next entry.

    Returns:
        ``([], None, "Understand Image")`` -- in this file these values are
        routed to the chatbot, the image state and the mode selector.
    """
    for entry in os.listdir(TEMP_DIR):
        target = os.path.join(TEMP_DIR, entry)
        try:
            if not os.path.isfile(target):
                continue
            os.remove(target)
        except Exception as err:
            print(f"Failed to delete temp file: {target}, error: {err}")

    return [], None, "Understand Image"
|
|
|
def extract_assistant_reply(full_text):
    """Return only the assistant's answer from raw model output.

    The answer is taken as the text after the last assistant role header,
    i.e. the word ``assistant`` opening a line (or directly following
    ``<|im_start|>``) and followed by ``:`` or the end of that line.  When no
    such header exists the previous behaviour is kept: the text after the
    last occurrence of the substring ``assistant`` is used, or the whole text
    if the word does not occur at all.

    Compared with splitting on the bare substring only, this keeps replies
    that themselves contain the word "assistant" intact, and the
    ``<|im_end|>`` token is now removed in every case rather than only when
    no marker was found.

    Args:
        full_text: Raw text produced by the model.

    Returns:
        The reply with ``<|im_end|>`` removed and surrounding whitespace
        stripped.
    """
    import re  # local import keeps this block independent of the file header

    headers = list(
        re.finditer(
            r"(?:^[ \t]*|<\|im_start\|>)assistant[ \t\r]*(?::|$)",
            full_text,
            flags=re.MULTILINE,
        )
    )
    if headers:
        reply = full_text[headers[-1].end():]
    elif "assistant" in full_text:
        # Fallback: marker embedded mid-line, handled as before.
        reply = full_text.strip().split("assistant")[-1].lstrip(":")
    else:
        reply = full_text

    return reply.replace("<|im_end|>", "").strip()
|
|
|
def on_submit(history, user_msg, img_path, mode, grid_size=1):
    """Handle one chat submission and return the updated UI values.

    The user's message is appended to a copy of ``history`` and then
    dispatched on ``mode``:

    * ``"Understand Image"`` -- asks the model about the current image.
    * ``"Generate Image"`` -- generates ``grid_size**2`` images from the
      prompt.
    * ``"Edit Image"`` -- edits the current image following the instruction.

    Any other mode yields a warning message instead of falling through and
    returning ``None``.  Source images are opened in a context manager so
    their file handles are released once the model call finishes.  Every
    generated image is posted as its own ``(path,)`` chat message, the same
    file-message form used for uploads in this file.

    Args:
        history: Chat history as a sequence of ``[user, assistant]`` pairs;
            ``None`` is treated as empty.  The input is never mutated.
        user_msg: Text typed by the user; ``None`` is treated as ``""``.
        img_path: Path of the image currently under discussion, or ``None``.
        mode: One of the mode labels listed above.
        grid_size: Passed to the model; generation requests
            ``grid_size**2`` images.

    Returns:
        ``(updated_history, "", image_path)`` -- the new history, an empty
        string for the input box, and the image path to remember (the last
        generated image when new ones were produced, otherwise ``img_path``).
    """
    # Work on a copy so the caller's history object is left untouched.
    updated_history = [list(item) for item in (history or [])]
    user_msg = (user_msg or "").strip()
    updated_history.append([user_msg, None])

    def reply_with(message, current_image=img_path):
        """Append an assistant message and build the handler's return value."""
        updated_history.append([None, message])
        return updated_history, "", current_image

    def publish(images):
        """Save produced images, post each one, and remember the last path."""
        paths = [save_temp_image(image) for image in images]
        if not paths:
            return reply_with("⚠️ No image was produced.")
        for path in paths:
            updated_history.append([None, (path,)])
        return updated_history, "", paths[-1]

    try:
        if mode == "Understand Image":
            if img_path is None:
                return reply_with("⚠️ Please upload or generate an image first.")
            if img_path:
                with Image.open(img_path) as image:
                    raw = inferencer.query_image(image, user_msg)
            else:
                # Only reachable for an empty (non-None) path value.
                raw = inferencer.query_text(user_msg)
            return reply_with(extract_assistant_reply(raw))

        if mode == "Generate Image":
            if not user_msg:
                return reply_with("⚠️ Please enter a prompt.")
            return publish(
                inferencer.gen_image(
                    raw_prompt=user_msg,
                    images_to_generate=grid_size**2,
                    cfg=3.0,
                    num_iter=48,
                    cfg_schedule="constant",
                    temperature=1.0,
                )
            )

        if mode == "Edit Image":
            if img_path is None:
                return reply_with("⚠️ Please upload or generate an image first.")
            if not user_msg:
                return reply_with("⚠️ Please enter an edit instruction.")
            with Image.open(img_path) as source:
                return publish(
                    inferencer.edit_image(
                        source_image=source,
                        prompt=user_msg,
                        cfg=3.0,
                        cfg_prompt="repeat this image.",
                        cfg_schedule="constant",
                        temperature=0.85,
                        grid_size=grid_size,
                        num_iter=48,
                    )
                )

        return reply_with(f"⚠️ Unsupported mode: {mode}")
    except Exception as e:
        # UI boundary: report the failure in the chat instead of raising.
        return reply_with(f"⚠️ Failed to process: {e}")
|
|
|
# Stylesheet handed to gr.Blocks(css=...).  Sections, in order:
#   1. .gradio-container / .gr-tabs -- full-height column flex layout.
#   2. #tab_item_4, #tab_item_5     -- chat tab as a flex column, overflow hidden.
#   3. #chatbot1, #chatbot2         -- chat box grows, capped at 66vh, scrolls.
#   4. chat images                  -- up to 80vw wide, rounded corners.
#   5. .input-row and its columns   -- bottom input bar and its flex ratios.
#   6. .gr-text-input               -- rounded text box styling.
#   7. .gr-button, .gr-upload       -- rounded full-width button styling.
# The CSS comments inside the string are part of the runtime text and are
# therefore left exactly as written.
CSS = """
/* 整体布局:上下两块 */
.gradio-container {
    display: flex !important;
    flex-direction: column;
    height: 100vh;
    margin: 0;
    padding: 0;
}
.gr-tabs { /* ✅ 新增:确保 tab 能继承高度 */
    flex: 1 1 auto;
    display: flex;
    flex-direction: column;
}

/* 聊天 tab */
#tab_item_4, #tab_item_5 {
    display: flex;
    flex-direction: column;
    flex: 1 1 auto;
    overflow: hidden; /* 防止出现双滚动条 */
    padding: 8px;
}

/* Chatbot 撑满 */
#chatbot1, #chatbot2{
    flex-grow: 1 !important;
    max-height: 66vh !important; /* 限制聊天框最大高度为屏幕的2/3 */
    overflow-y: auto !important; /* 当内容溢出时显示滚动条 */
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 8px;
}

/* 图片消息放大 */
#chatbot1 img, #chatbot2 img {
    max-width: 80vw !important;
    height: auto !important;
    border-radius: 4px;
}

/* 底部输入区:固定高度 */
.input-row {
    flex: 0 0 auto;
    display: flex;
    align-items: center;
    padding: 8px;
    border-top: 1px solid #eee;
    background: #fafafa;
}

/* 文本框和按钮排布 */
.input-row .textbox-col { flex: 5; }
.input-row .upload-col, .input-row .clear-col { flex: 1; margin-left: 8px; }

/* 文本框样式 */
.gr-text-input {
    width: 100% !important;
    border-radius: 18px !important;
    padding: 8px 16px !important;
    border: 1px solid #ddd !important;
    font-size: 16px !important;
}

/* 按钮和上传组件样式 */
.gr-button, .gr-upload {
    width: 100% !important;
    border-radius: 18px !important;
    padding: 8px 16px !important;
    font-size: 16px !important;
}
"""
|
|
|
with gr.Blocks(css=CSS) as demo:
    # Path of the image currently under discussion; fed to and returned by
    # on_submit, replaced on upload, and reset by clear_all.
    img_state = gr.State(value=None)
    # Declared but not wired to any event in this file.
    mode_state = gr.State(value="Understand Image")

    with gr.Tabs():
        with gr.Tab("Skywork UniPic Chatbot", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                avatar_images=("user.png", "ai.png"),
            )

            with gr.Row():
                mode_selector = gr.Radio(
                    label="Mode",
                    choices=["Generate Image", "Edit Image", "Understand Image"],
                    value="Generate Image",
                    interactive=True,
                )

            # Bottom bar: text box, upload button and clear button side by side.
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        placeholder="Type your message here...",
                        show_label=False,
                        lines=1,
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        file_types=["image"],
                        file_count="single",
                        type="filepath",
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")

    # Enter in the text box runs the selected mode on the typed message.
    user_input.submit(
        fn=on_submit,
        inputs=[chatbot, user_input, img_state, mode_selector],
        outputs=[chatbot, user_input, img_state],
    )

    # A finished upload is copied to the temp directory and shown in the chat.
    image_input.upload(
        fn=handle_image_upload,
        inputs=[image_input, chatbot],
        outputs=[img_state, chatbot],
    )

    # Clearing wipes the temp files and resets chat, image state and mode.
    clear_btn.click(fn=clear_all, outputs=[chatbot, img_state, mode_selector])

demo.launch()