yichenchenchen commited on
Commit
40888ec
·
verified ·
1 Parent(s): 7d3552a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +302 -0
app.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import os
5
+ import tempfile
6
+ import sys
7
+ import time
8
+ from inferencer_metaquery import UniPicV2Inferencer
9
+ from accelerate.utils import set_seed
10
+ from huggingface_hub import snapshot_download
11
+
12
+ model_path = snapshot_download(repo_id="Skywork/UniPic2-Metaquery-9B")
13
+ qwen_vl_path = snapshot_download(repo_id="Qwen/Qwen2.5-VL-7B-Instruct-AWQ")
14
+
15
+ inferencer = UniPicV2Inferencer(
16
+ model_path=model_path,
17
+ qwen_vl_path=qwen_vl_path,
18
+ quant="int4"
19
+ )
20
+
21
+ TEMP_DIR = tempfile.mkdtemp()
22
+ print(f"Temporary directory created at: {TEMP_DIR}")
23
+
24
+ def save_temp_image(pil_img):
25
+ path = os.path.join(TEMP_DIR, f"temp_{int(time.time())}.png")
26
+ pil_img.save(path, format="PNG")
27
+ return path
28
+
29
+
30
+ def handle_image_upload(file, history):
31
+ if file is None:
32
+ return None, history
33
+ file_path = file.name if hasattr(file, "name") else file
34
+ pil_img = Image.open(file_path)
35
+ saved_path = save_temp_image(pil_img)
36
+ return saved_path, history + [((saved_path,), None)]
37
+
38
+
39
+ def clear_all():
40
+ for file in os.listdir(TEMP_DIR):
41
+ path = os.path.join(TEMP_DIR, file)
42
+ try:
43
+ if os.path.isfile(path):
44
+ os.remove(path)
45
+ except Exception as e:
46
+ print(f"Failed to delete temp file: {path}, error: {e}")
47
+ return [], None, "Language Output"
48
+
49
+
50
+ def extract_assistant_reply(full_text):
51
+ if "assistant" in full_text:
52
+ parts = full_text.strip().split("assistant")
53
+ return parts[-1].lstrip(":").strip()
54
+ return full_text.replace("<|im_end|>", "").strip()
55
+
56
+
57
+ def on_submit(history, user_msg, img_path, mode, infer_steps, cfg_scale, model_version='sd3.5-512', seed=3):
58
+ user_msg = user_msg.strip()
59
+ updated_history = history + [(user_msg, None)]
60
+ edit_tip = "✅ You can continue editing this image by switching the mode to 'Edit Image' and entering your instruction. "
61
+ # seed = int(seed)
62
+ # set_seed(seed) # 设置随机种子以确保可重复性
63
+
64
+ # try:
65
+ if mode == "Understand Image":
66
+ if img_path is None:
67
+ updated_history.append([None, "⚠️ Please upload or generate an image first."])
68
+ return updated_history, "", img_path
69
+
70
+ raw_output = (
71
+ inferencer.understand_image(Image.open(img_path), user_msg)
72
+ if img_path
73
+ else inferencer.query_text(user_msg)
74
+ )
75
+ clean_output = raw_output
76
+ return (
77
+ updated_history + [(None, clean_output)],
78
+ "",
79
+ img_path,
80
+ ) # 保持 img_path 不变
81
+
82
+ if mode == "Generate Image":
83
+ if not user_msg:
84
+ return (
85
+ updated_history
86
+ + [(None, "⚠️ Please enter your message for generating.")],
87
+ "",
88
+ img_path,
89
+ )
90
+ imgs = inferencer.generate_image(user_msg, num_inference_steps=infer_steps, seed=seed, guidance_scale=cfg_scale)
91
+ path = save_temp_image(imgs[0])
92
+ return (
93
+ updated_history
94
+ + [
95
+ (None, (path,)),
96
+ (
97
+ None,
98
+ "✅ You can continue editing this image by switching the mode to 'Edit Image' and entering your instruction. ",
99
+ ),
100
+ ],
101
+ "",
102
+ path,
103
+ ) # 更新 img_state
104
+
105
+ elif mode == "Edit Image":
106
+ if img_path is None:
107
+ return (
108
+ updated_history
109
+ + [(None, "⚠️ Please upload or generate an image first.")],
110
+ "",
111
+ img_path,
112
+ )
113
+ if not user_msg:
114
+ return (
115
+ updated_history
116
+ + [(None, "⚠️ Please enter your message for editing.")],
117
+ "",
118
+ img_path,
119
+ )
120
+ edited_img = inferencer.edit_image(Image.open(img_path), user_msg, num_inference_steps=infer_steps, seed=seed, guidance_scale=cfg_scale)[0]
121
+ path = save_temp_image(edited_img)
122
+ return (
123
+ updated_history
124
+ + [
125
+ (None, (path,)),
126
+ (
127
+ None,
128
+ "✅ You can continue editing this image by entering your instruction. ",
129
+ ),
130
+ ],
131
+ "",
132
+ path,
133
+ ) # 更新 img_state
134
+
135
+ # except Exception as e:
136
+ # return updated_history + [(None, f"⚠️ Failed to process: {e}")], "", img_path
137
+
138
+
139
+ # 定义CSS样式(兼容旧版 Gradio)
140
+ CSS = """
141
+ /* 整体布局 */
142
+ .gradio-container {
143
+ display: flex !important;
144
+ flex-direction: column;
145
+ height: 100vh;
146
+ margin: 0;
147
+ padding: 0;
148
+ }
149
+ /* 让 tab 自适应填满剩余高度 */
150
+ .gr-tabs {
151
+ flex: 1 1 auto;
152
+ display: flex;
153
+ flex-direction: column;
154
+ }
155
+ /* 聊天 tab 主体 */
156
+ #tab_item_4 {
157
+ display: flex;
158
+ flex-direction: column;
159
+ flex: 1 1 auto;
160
+ overflow: hidden;
161
+ padding: 8px;
162
+ }
163
+ /* Chatbot 区域 */
164
+ #chatbot1 {
165
+ flex: 0 0 55vh !important;
166
+ overflow-y: auto !important;
167
+ border: 1px solid #ddd;
168
+ border-radius: 8px;
169
+ padding: 12px;
170
+ margin-bottom: 8px;
171
+ }
172
+ #chatbot1 img {
173
+ max-width: 80vw !important;
174
+ height: auto !important;
175
+ border-radius: 4px;
176
+ }
177
+ /* 控制面板外框 */
178
+ .control-panel {
179
+ border: 1px solid #ddd;
180
+ border-radius: 8px;
181
+ padding: 8px;
182
+ margin-bottom: 8px;
183
+ }
184
+ /* 控制行:三列不换行 */
185
+ .control-row {
186
+ display: flex;
187
+ align-items: stretch;
188
+ gap: 8px;
189
+ flex-wrap: nowrap; /* 不换行 */
190
+ }
191
+ /* 单个控件卡片样式 */
192
+ .control-box {
193
+ border: 1px solid #ccc;
194
+ border-radius: 8px;
195
+ padding: 8px;
196
+ background: #f9f9f9;
197
+ flex: 1;
198
+ min-width: 150px;
199
+ box-sizing: border-box;
200
+ }
201
+ /* 强制旧版 Radio 横向排列 */
202
+ .control-box .wrap {
203
+ display: flex !important;
204
+ flex-direction: row !important;
205
+ gap: 8px !important;
206
+ }
207
+
208
+ /* 输入区 */
209
+ .input-row {
210
+ flex: 0 0 auto;
211
+ display: flex;
212
+ align-items: center;
213
+ padding: 8px;
214
+ border-top: 1px solid #eee;
215
+ background: #fafafa;
216
+ }
217
+ .textbox-col { flex: 1; }
218
+ .upload-col, .clear-col { flex: 0 0 120px; }
219
+ .gr-text-input {
220
+ width: 100% !important;
221
+ border-radius: 18px !important;
222
+ padding: 8px 16px !important;
223
+ border: 1px solid #ddd !important;
224
+ font-size: 16px !important;
225
+ }
226
+ """
227
+
228
+ with gr.Blocks(css=CSS) as demo:
229
+ img_state = gr.State(value=None)
230
+
231
+ with gr.Tabs():
232
+ with gr.Tab("Skywork UniPic2-SD3.5M-Kontext", elem_id="tab_item_4"):
233
+ chatbot = gr.Chatbot(
234
+ elem_id="chatbot1",
235
+ show_label=False,
236
+ avatar_images=("user.png", "ai.png"),
237
+ )
238
+
239
+ # 控制区域
240
+ with gr.Column(elem_classes="control-panel"):
241
+ with gr.Row(elem_classes="control-row"):
242
+ with gr.Column(elem_classes="control-box", scale=2, min_width=200):
243
+ mode_selector = gr.Radio(
244
+ choices=["Generate Image", "Edit Image", "Understand Image"],
245
+ value="Generate Image",
246
+ label="Mode",
247
+ interactive=True
248
+ )
249
+ with gr.Column(elem_classes="control-box", scale=1, min_width=150):
250
+ infer_steps = gr.Slider(
251
+ label="Sample Steps",
252
+ minimum=1,
253
+ maximum=200,
254
+ value=50,
255
+ step=1,
256
+ interactive=True,
257
+ )
258
+ with gr.Column(elem_classes="control-box", scale=1, min_width=150):
259
+ cfg_scale = gr.Slider(
260
+ label="CFG Scale",
261
+ minimum=1.0,
262
+ maximum=16.0,
263
+ value=3.5,
264
+ step=0.5,
265
+ interactive=True,
266
+ )
267
+
268
+ # 输入区域
269
+ with gr.Row(elem_classes="input-row"):
270
+ with gr.Column(elem_classes="textbox-col"):
271
+ user_input = gr.Textbox(
272
+ placeholder="Type your message here...",
273
+ label="prompt",
274
+ lines=1,
275
+ )
276
+ with gr.Column(elem_classes="upload-col"):
277
+ image_input = gr.UploadButton(
278
+ "📷 Upload Image",
279
+ file_types=["image"],
280
+ file_count="single",
281
+ type="filepath",
282
+ )
283
+ with gr.Column(elem_classes="clear-col"):
284
+ clear_btn = gr.Button("🧹 Clear History")
285
+
286
+ # 交互绑定
287
+ user_input.submit(
288
+ on_submit,
289
+ inputs=[chatbot, user_input, img_state, mode_selector, infer_steps, cfg_scale],
290
+ outputs=[chatbot, user_input, img_state],
291
+ )
292
+ image_input.upload(
293
+ handle_image_upload, inputs=[image_input, chatbot], outputs=[img_state, chatbot]
294
+ )
295
+ clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])
296
+
297
+ demo.launch()
298
+
299
+
300
+
301
+ #if __name__ == "__main__":
302
+ # demo.launch(server_name="0.0.0.0", debug=True, server_port=8004)