ginipick commited on
Commit
b32c775
ยท
verified ยท
1 Parent(s): 5ad049a

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -802
app.py DELETED
@@ -1,802 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import os
4
- import re
5
- import tempfile
6
- import gc # garbage collector ์ถ”๊ฐ€
7
- from collections.abc import Iterator
8
- from threading import Thread
9
- import json
10
- import requests
11
- import cv2
12
- import base64
13
- import logging
14
- import time
15
- from urllib.parse import quote # URL ์ธ์ฝ”๋”ฉ (ํ•„์š” ์‹œ ์‚ฌ์šฉ)
16
-
17
- import gradio as gr
18
- import spaces
19
- import torch
20
- from loguru import logger
21
- from PIL import Image
22
- from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
23
-
24
- # CSV/TXT/PDF ๋ถ„์„
25
- import pandas as pd
26
- import PyPDF2
27
-
28
- # =============================================================================
29
- # (์‹ ๊ทœ) ์ด๋ฏธ์ง€ API ๊ด€๋ จ ํ•จ์ˆ˜๋“ค
30
- # =============================================================================
31
- from gradio_client import Client
32
-
33
- API_URL = "http://211.233.58.201:7896"
34
-
35
- logging.basicConfig(
36
- level=logging.DEBUG,
37
- format='%(asctime)s - %(levelname)s - %(message)s'
38
- )
39
-
40
- def test_api_connection() -> str:
41
- """API ์„œ๋ฒ„ ์—ฐ๊ฒฐ ํ…Œ์ŠคํŠธ"""
42
- try:
43
- client = Client(API_URL)
44
- return "API ์—ฐ๊ฒฐ ์„ฑ๊ณต: ์ •์ƒ ์ž‘๋™ ์ค‘"
45
- except Exception as e:
46
- logging.error(f"API connection test failed: {e}")
47
- return f"API ์—ฐ๊ฒฐ ์‹คํŒจ: {e}"
48
-
49
- def generate_image(prompt: str, width: float, height: float, guidance: float, inference_steps: float, seed: float):
50
- """
51
- ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ•จ์ˆ˜.
52
- ์—ฌ๊ธฐ์„œ๋Š” ์„œ๋ฒ„๊ฐ€ ์ตœ์ข… ์ด๋ฏธ์ง€๋ฅผ Base64(๋˜๋Š” data:image/...) ํ˜•ํƒœ๋กœ ์ง์ ‘ ๋ฐ˜ํ™˜ํ•œ๋‹ค๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค.
53
- /tmp/... ๊ฒฝ๋กœ๋‚˜ ์ถ”๊ฐ€ ๋‹ค์šด๋กœ๋“œ๋ฅผ ์‹œ๋„ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
54
- """
55
- if not prompt:
56
- return None, "Error: Prompt is required"
57
- try:
58
- logging.info(f"Calling image generation API with prompt: {prompt}")
59
-
60
- client = Client(API_URL)
61
- result = client.predict(
62
- prompt=prompt,
63
- width=int(width),
64
- height=int(height),
65
- guidance=float(guidance),
66
- inference_steps=int(inference_steps),
67
- seed=int(seed),
68
- do_img2img=False,
69
- init_image=None,
70
- image2image_strength=0.8,
71
- resize_img=True,
72
- api_name="/generate_image"
73
- )
74
-
75
- logging.info(
76
- f"Image generation result: {type(result)}, "
77
- f"length: {len(result) if isinstance(result, (list, tuple)) else 'unknown'}"
78
- )
79
-
80
- # ๊ฒฐ๊ณผ๊ฐ€ ํŠœํ”Œ/๋ฆฌ์ŠคํŠธ: [์ด๋ฏธ์ง€_base64 or data_url, seed_info] ๋กœ ๊ฐ€์ •
81
- if isinstance(result, (list, tuple)) and len(result) > 0:
82
- image_data = result[0] # ์ฒซ ๋ฒˆ์งธ ์š”์†Œ๊ฐ€ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ (Base64 or data:image/... ๋“ฑ)
83
- seed_info = result[1] if len(result) > 1 else "Unknown seed"
84
- return image_data, seed_info
85
- else:
86
- # ๋‹ค๋ฅธ ํ˜•ํƒœ๋กœ ๋ฐ˜ํ™˜๋œ ๊ฒฝ์šฐ
87
- return result, "Unknown seed"
88
-
89
- except Exception as e:
90
- logging.error(f"Image generation failed: {str(e)}")
91
- return None, f"Error: {str(e)}"
92
-
93
- # Base64 ํŒจ๋”ฉ ์ˆ˜์ • ํ•จ์ˆ˜ (ํ•„์š”ํ•˜๋‹ค๋ฉด ์‚ฌ์šฉ)
94
- def fix_base64_padding(data):
95
- """Base64 ๋ฌธ์ž์—ด์˜ ํŒจ๋”ฉ์„ ์ˆ˜์ •ํ•ฉ๋‹ˆ๋‹ค."""
96
- if isinstance(data, bytes):
97
- data = data.decode('utf-8')
98
-
99
- if "base64," in data:
100
- data = data.split("base64,", 1)[1]
101
-
102
- missing_padding = len(data) % 4
103
- if missing_padding:
104
- data += '=' * (4 - missing_padding)
105
-
106
- return data
107
-
108
- # =============================================================================
109
- # ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ ํ•จ์ˆ˜
110
- # =============================================================================
111
- def clear_cuda_cache():
112
- """CUDA ์บ์‹œ๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ๋น„์›๋‹ˆ๋‹ค."""
113
- if torch.cuda.is_available():
114
- torch.cuda.empty_cache()
115
- gc.collect()
116
-
117
- # =============================================================================
118
- # SerpHouse ๊ด€๋ จ ํ•จ์ˆ˜
119
- # =============================================================================
120
- SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
121
-
122
- def extract_keywords(text: str, top_k: int = 5) -> str:
123
- """๋‹จ์ˆœ ํ‚ค์›Œ๋“œ ์ถ”์ถœ: ํ•œ๊ธ€, ์˜์–ด, ์ˆซ์ž, ๊ณต๋ฐฑ๋งŒ ๋‚จ๊น€"""
124
- text = re.sub(r"[^a-zA-Z0-9๊ฐ€-ํžฃ\s]", "", text)
125
- tokens = text.split()
126
- return " ".join(tokens[:top_k])
127
-
128
- def do_web_search(query: str) -> str:
129
- """
130
- SerpHouse LIVE API ํ˜ธ์ถœํ•˜์—ฌ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ๋งˆํฌ๋‹ค์šด ๋ฐ˜ํ™˜
131
- (ํ•„์š”ํ•˜๋‹ค๋ฉด ์ˆ˜์ • or ์‚ญ์ œ ๊ฐ€๋Šฅ)
132
- """
133
- try:
134
- url = "https://api.serphouse.com/serp/live"
135
- params = {
136
- "q": query,
137
- "domain": "google.com",
138
- "serp_type": "web",
139
- "device": "desktop",
140
- "lang": "en",
141
- "num": "20"
142
- }
143
- headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}
144
- logger.info(f"SerpHouse API ํ˜ธ์ถœ ์ค‘... ๊ฒ€์ƒ‰์–ด: {query}")
145
- response = requests.get(url, headers=headers, params=params, timeout=60)
146
- response.raise_for_status()
147
- data = response.json()
148
- results = data.get("results", {})
149
- organic = None
150
- if isinstance(results, dict) and "organic" in results:
151
- organic = results["organic"]
152
- elif isinstance(results, dict) and "results" in results:
153
- if isinstance(results["results"], dict) and "organic" in results["results"]:
154
- organic = results["results"]["organic"]
155
- elif "organic" in data:
156
- organic = data["organic"]
157
- if not organic:
158
- logger.warning("์‘๋‹ต์—์„œ organic ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
159
- return "No web search results found or unexpected API response structure."
160
- max_results = min(20, len(organic))
161
- limited_organic = organic[:max_results]
162
- summary_lines = []
163
- for idx, item in enumerate(limited_organic, start=1):
164
- title = item.get("title", "No title")
165
- link = item.get("link", "#")
166
- snippet = item.get("snippet", "No description")
167
- displayed_link = item.get("displayed_link", link)
168
- summary_lines.append(
169
- f"### Result {idx}: {title}\n\n"
170
- f"{snippet}\n\n"
171
- f"**์ถœ์ฒ˜**: [{displayed_link}]({link})\n\n"
172
- f"---\n"
173
- )
174
- instructions = """
175
- # ์›น ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ
176
- ์•„๋ž˜๋Š” ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์ž…๋‹ˆ๋‹ค. ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•  ๋•Œ ์ด ์ •๋ณด๋ฅผ ํ™œ์šฉํ•˜์„ธ์š”:
177
- 1. ์—ฌ๋Ÿฌ ์ถœ์ฒ˜ ๋‚ด์šฉ์„ ์ข…ํ•ฉํ•˜์—ฌ ๋‹ต๋ณ€.
178
- 2. ์ถœ์ฒ˜ ์ธ์šฉ ์‹œ "[์ถœ์ฒ˜ ์ œ๋ชฉ](๋งํฌ)" ๋งˆํฌ๋‹ค์šด ํ˜•์‹ ์‚ฌ์šฉ.
179
- 3. ๋‹ต๋ณ€ ๋งˆ์ง€๋ง‰์— '์ฐธ๊ณ  ์ž๋ฃŒ:' ์„น์…˜์— ์‚ฌ์šฉํ•œ ์ฃผ์š” ์ถœ์ฒ˜๋ฅผ ๋‚˜์—ด.
180
- """
181
- return instructions + "\n".join(summary_lines)
182
- except Exception as e:
183
- logger.error(f"Web search failed: {e}")
184
- return f"Web search failed: {str(e)}"
185
-
186
- # =============================================================================
187
- # ๋ชจ๋ธ ๋ฐ ํ”„๋กœ์„ธ์„œ ๋กœ๋”ฉ
188
- # =============================================================================
189
- MAX_CONTENT_CHARS = 2000
190
- MAX_INPUT_LENGTH = 2096
191
-
192
- model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
193
- processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
194
- model = Gemma3ForConditionalGeneration.from_pretrained(
195
- model_id,
196
- device_map="auto",
197
- torch_dtype=torch.bfloat16,
198
- attn_implementation="eager"
199
- )
200
-
201
- MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
202
-
203
- # =============================================================================
204
- # CSV, TXT, PDF ๋ถ„์„ ํ•จ์ˆ˜
205
- # =============================================================================
206
- def analyze_csv_file(path: str) -> str:
207
- try:
208
- df = pd.read_csv(path)
209
- if df.shape[0] > 50 or df.shape[1] > 10:
210
- df = df.iloc[:50, :10]
211
- df_str = df.to_string()
212
- if len(df_str) > MAX_CONTENT_CHARS:
213
- df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
214
- return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
215
- except Exception as e:
216
- return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
217
-
218
- def analyze_txt_file(path: str) -> str:
219
- try:
220
- with open(path, "r", encoding="utf-8") as f:
221
- text = f.read()
222
- if len(text) > MAX_CONTENT_CHARS:
223
- text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
224
- return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
225
- except Exception as e:
226
- return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
227
-
228
- def pdf_to_markdown(pdf_path: str) -> str:
229
- text_chunks = []
230
- try:
231
- with open(pdf_path, "rb") as f:
232
- reader = PyPDF2.PdfReader(f)
233
- max_pages = min(5, len(reader.pages))
234
- for page_num in range(max_pages):
235
- page_text = reader.pages[page_num].extract_text() or ""
236
- page_text = page_text.strip()
237
- if page_text:
238
- if len(page_text) > MAX_CONTENT_CHARS // max_pages:
239
- page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
240
- text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
241
- if len(reader.pages) > max_pages:
242
- text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...")
243
- except Exception as e:
244
- return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"
245
- full_text = "\n".join(text_chunks)
246
- if len(full_text) > MAX_CONTENT_CHARS:
247
- full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
248
- return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
249
-
250
- # =============================================================================
251
- # ์ด๋ฏธ์ง€/๋น„๋””์˜ค ํŒŒ์ผ ์ œํ•œ ๊ฒ€์‚ฌ
252
- # =============================================================================
253
- def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
254
- image_count = 0
255
- video_count = 0
256
- for path in paths:
257
- if path.endswith(".mp4"):
258
- video_count += 1
259
- elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE):
260
- image_count += 1
261
- return image_count, video_count
262
-
263
- def count_files_in_history(history: list[dict]) -> tuple[int, int]:
264
- image_count = 0
265
- video_count = 0
266
- for item in history:
267
- if item["role"] != "user" or isinstance(item["content"], str):
268
- continue
269
- if isinstance(item["content"], list) and len(item["content"]) > 0:
270
- file_path = item["content"][0]
271
- if isinstance(file_path, str):
272
- if file_path.endswith(".mp4"):
273
- video_count += 1
274
- elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE):
275
- image_count += 1
276
- return image_count, video_count
277
-
278
- def validate_media_constraints(message: dict, history: list[dict]) -> bool:
279
- """์ด๋ฏธ์ง€/๋น„๋””์˜ค ์—…๋กœ๋“œ ์ œํ•œ ๊ฒ€์‚ฌ."""
280
- media_files = [f for f in message["files"]
281
- if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4")]
282
- new_image_count, new_video_count = count_files_in_new_message(media_files)
283
- history_image_count, history_video_count = count_files_in_history(history)
284
-
285
- image_count = history_image_count + new_image_count
286
- video_count = history_video_count + new_video_count
287
-
288
- if video_count > 1:
289
- gr.Warning("Only one video is supported.")
290
- return False
291
- if video_count == 1:
292
- if image_count > 0:
293
- gr.Warning("Mixing images and videos is not allowed.")
294
- return False
295
- if "<image>" in message["text"]:
296
- gr.Warning("Using <image> tags with video files is not supported.")
297
- return False
298
- if video_count == 0 and image_count > MAX_NUM_IMAGES:
299
- gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
300
- return False
301
- if "<image>" in message["text"]:
302
- image_files = [f for f in message["files"]
303
- if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
304
- image_tag_count = message["text"].count("<image>")
305
- if image_tag_count != len(image_files):
306
- gr.Warning("The number of <image> tags in the text does not match the number of image files.")
307
- return False
308
- return True
309
-
310
- # =============================================================================
311
- # ๋น„๋””์˜ค ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
312
- # =============================================================================
313
- def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
314
- vidcap = cv2.VideoCapture(video_path)
315
- fps = vidcap.get(cv2.CAP_PROP_FPS)
316
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
317
- frame_interval = max(int(fps), int(total_frames / 10))
318
- frames = []
319
- for i in range(0, total_frames, frame_interval):
320
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
321
- success, image = vidcap.read()
322
- if success:
323
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
324
- image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
325
- pil_image = Image.fromarray(image)
326
- timestamp = round(i / fps, 2)
327
- frames.append((pil_image, timestamp))
328
- if len(frames) >= 5:
329
- break
330
- vidcap.release()
331
- return frames
332
-
333
- def process_video(video_path: str) -> tuple[list[dict], list[str]]:
334
- content = []
335
- temp_files = []
336
- frames = downsample_video(video_path)
337
- for pil_image, timestamp in frames:
338
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
339
- pil_image.save(temp_file.name)
340
- temp_files.append(temp_file.name)
341
- content.append({"type": "text", "text": f"Frame {timestamp}:"})
342
- content.append({"type": "image", "url": temp_file.name})
343
- return content, temp_files
344
-
345
- # =============================================================================
346
- # interleaved <image> ์ฒ˜๋ฆฌ ํ•จ์ˆ˜ (<image> ํƒœ๊ทธ์™€ ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ ํ˜ผํ•ฉ ์ง€์›)
347
- # =============================================================================
348
- def process_interleaved_images(message: dict) -> list[dict]:
349
- parts = re.split(r"(<image>)", message["text"])
350
- content = []
351
- image_files = [f for f in message["files"]
352
- if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
353
- image_index = 0
354
- for part in parts:
355
- if part == "<image>" and image_index < len(image_files):
356
- content.append({"type": "image", "url": image_files[image_index]})
357
- image_index += 1
358
- elif part.strip():
359
- content.append({"type": "text", "text": part.strip()})
360
- else:
361
- if isinstance(part, str) and part != "<image>":
362
- content.append({"type": "text", "text": part})
363
- return content
364
-
365
- # =============================================================================
366
- # ํŒŒ์ผ ์ฒ˜๋ฆฌ -> content ์ƒ์„ฑ
367
- # =============================================================================
368
- def is_image_file(file_path: str) -> bool:
369
- return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
370
-
371
- def is_video_file(file_path: str) -> bool:
372
- return file_path.endswith(".mp4")
373
-
374
- def is_document_file(file_path: str) -> bool:
375
- return file_path.lower().endswith(".pdf") or file_path.lower().endswith(".csv") or file_path.lower().endswith(".txt")
376
-
377
- def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
378
- """์‚ฌ์šฉ์ž๊ฐ€ ์ƒˆ๋กœ ์ž…๋ ฅํ•œ ๋ฉ”์‹œ์ง€ + ์—…๋กœ๋“œ ํŒŒ์ผ๋“ค์„ ํ•˜๋‚˜์˜ content(list)๋กœ ๋ณ€ํ™˜."""
379
- temp_files = []
380
- if not message["files"]:
381
- return [{"type": "text", "text": message["text"]}], temp_files
382
-
383
- video_files = [f for f in message["files"] if is_video_file(f)]
384
- image_files = [f for f in message["files"] if is_image_file(f)]
385
- csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
386
- txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
387
- pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
388
-
389
- content_list = [{"type": "text", "text": message["text"]}]
390
-
391
- # ๋ฌธ์„œ๋“ค
392
- for csv_path in csv_files:
393
- content_list.append({"type": "text", "text": analyze_csv_file(csv_path)})
394
- for txt_path in txt_files:
395
- content_list.append({"type": "text", "text": analyze_txt_file(txt_path)})
396
- for pdf_path in pdf_files:
397
- content_list.append({"type": "text", "text": pdf_to_markdown(pdf_path)})
398
-
399
- # ๋น„๋””์˜ค ์ฒ˜๋ฆฌ
400
- if video_files:
401
- video_content, video_temp_files = process_video(video_files[0])
402
- content_list += video_content
403
- temp_files.extend(video_temp_files)
404
- return content_list, temp_files
405
-
406
- # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ
407
- if "<image>" in message["text"] and image_files:
408
- interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files})
409
- if content_list and content_list[0]["type"] == "text":
410
- content_list = content_list[1:]
411
- return interleaved_content + content_list, temp_files
412
- else:
413
- for img_path in image_files:
414
- content_list.append({"type": "image", "url": img_path})
415
-
416
- return content_list, temp_files
417
-
418
- # =============================================================================
419
- # history -> LLM ๋ฉ”์‹œ์ง€ ๋ณ€ํ™˜
420
- # =============================================================================
421
- def process_history(history: list[dict]) -> list[dict]:
422
- """
423
- ๊ธฐ์กด ๋Œ€ํ™” ๊ธฐ๋ก์„ LLM์— ๋งž๊ฒŒ ๋ณ€ํ™˜.
424
- - user -> {"role":"user","content":[{type,text},...]}
425
- - assistant -> {"role":"assistant","content":[{type:"text",text},...]}
426
- """
427
- messages = []
428
- current_user_content = []
429
- for item in history:
430
- if item["role"] == "assistant":
431
- # ์‚ฌ์šฉ์ž content ๋ˆ„์ ๋ถ„์ด ์žˆ์œผ๋ฉด ํ•œ๋ฒˆ์— user๋กœ ์ถ”๊ฐ€
432
- if current_user_content:
433
- messages.append({"role": "user", "content": current_user_content})
434
- current_user_content = []
435
- # assistant ๋ฐ”๋กœ ์ถ”๊ฐ€
436
- messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
437
- else:
438
- content = item["content"]
439
- if isinstance(content, str):
440
- current_user_content.append({"type": "text", "text": content})
441
- elif isinstance(content, list) and len(content) > 0:
442
- file_path = content[0]
443
- if is_image_file(file_path):
444
- current_user_content.append({"type": "image", "url": file_path})
445
- else:
446
- current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})
447
- if current_user_content:
448
- messages.append({"role": "user", "content": current_user_content})
449
- return messages
450
-
451
- # =============================================================================
452
- # ๋ชจ๋ธ ์ƒ์„ฑ ํ•จ์ˆ˜ (OOM ์บ์น˜)
453
- # =============================================================================
454
- def _model_gen_with_oom_catch(**kwargs):
455
- try:
456
- model.generate(**kwargs)
457
- except torch.cuda.OutOfMemoryError:
458
- raise RuntimeError("[OutOfMemoryError] GPU ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ๋ถ€์กฑํ•ฉ๋‹ˆ๋‹ค.")
459
- finally:
460
- clear_cuda_cache()
461
-
462
- # =============================================================================
463
- # ๋ฉ”์ธ ์ถ”๋ก  ํ•จ์ˆ˜
464
- # =============================================================================
465
- @spaces.GPU(duration=120)
466
- def run(
467
- message: dict,
468
- history: list[dict],
469
- system_prompt: str = "",
470
- max_new_tokens: int = 512,
471
- use_web_search: bool = False,
472
- web_search_query: str = "",
473
- age_group: str = "20๋Œ€",
474
- mbti_personality: str = "INTP",
475
- sexual_openness: int = 2,
476
- image_gen: bool = False
477
- ) -> Iterator[str]:
478
- """
479
- LLM ์ถ”๋ก  ํ•จ์ˆ˜.
480
- - ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹œ, ์„œ๋ฒ„๊ฐ€ Base64(๋˜๋Š” data:image/... ํ˜•ํƒœ)๋ฅผ ์ง์ ‘ ๋ฐ˜ํ™˜ํ•œ๋‹ค๊ณ  ๊ฐ€์ •.
481
- - /tmp/... ํŒŒ์ผ์— ๋Œ€ํ•œ ์žฌ๋‹ค์šด๋กœ๋“œ๋ฅผ ์‹œ๋„ํ•˜์ง€ ์•Š์Œ (403 Forbidden ๋ฌธ์ œ ํšŒํ”ผ).
482
- """
483
- if not validate_media_constraints(message, history):
484
- yield ""
485
- return
486
-
487
- temp_files = []
488
- try:
489
- # 1) ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ + ํŽ˜๋ฅด์†Œ๋‚˜ ์ •๋ณด
490
- persona = (
491
- f"{system_prompt.strip()}\n\n"
492
- f"Gender: Female\n"
493
- f"Age Group: {age_group}\n"
494
- f"MBTI Persona: {mbti_personality}\n"
495
- f"Sexual Openness (1~5): {sexual_openness}\n"
496
- )
497
- combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
498
-
499
- # 2) ์›น ๊ฒ€์ƒ‰ (์˜ต์…˜)
500
- if use_web_search:
501
- user_text = message["text"]
502
- ws_query = extract_keywords(user_text)
503
- if ws_query.strip():
504
- logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
505
- ws_result = do_web_search(ws_query)
506
- combined_system_msg += f"[Search top-20 Full Items]\n{ws_result}\n\n"
507
- combined_system_msg += (
508
- "[์ฐธ๊ณ : ์œ„ ๊ฒ€์ƒ‰๊ฒฐ๊ณผ link๋ฅผ ์ถœ์ฒ˜๋กœ ์ธ์šฉํ•˜์—ฌ ๋‹ต๋ณ€]\n"
509
- "[์ค‘์š” ์ง€์‹œ์‚ฌํ•ญ]\n"
510
- "1. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์—์„œ ์ฐพ์€ ์ •๋ณด์˜ ์ถœ์ฒ˜๋ฅผ ๋ฐ˜๋“œ์‹œ ์ธ์šฉ.\n"
511
- "2. '[์ถœ์ฒ˜ ์ œ๋ชฉ](๋งํฌ)' ํ˜•์‹์œผ๋กœ ๋งํฌ.\n"
512
- "3. ๋‹ต๋ณ€ ๋งˆ์ง€๋ง‰์— '์ฐธ๊ณ  ์ž๋ฃŒ:' ์„น์…˜.\n"
513
- )
514
- else:
515
- combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"
516
-
517
- # 3) ๊ธฐ์กด history + ์ƒˆ user ๋ฉ”์‹œ์ง€
518
- messages = []
519
- if combined_system_msg.strip():
520
- messages.append({"role": "system", "content": [{"type": "text", "text": combined_system_msg.strip()}]})
521
- messages.extend(process_history(history))
522
-
523
- user_content, user_temp_files = process_new_user_message(message)
524
- temp_files.extend(user_temp_files)
525
-
526
- for item in user_content:
527
- if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
528
- item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
529
-
530
- messages.append({"role": "user", "content": user_content})
531
-
532
- # 4) ํ† ํฌ๋‚˜์ด์ง•
533
- inputs = processor.apply_chat_template(
534
- messages,
535
- add_generation_prompt=True,
536
- tokenize=True,
537
- return_dict=True,
538
- return_tensors="pt",
539
- ).to(device=model.device, dtype=torch.bfloat16)
540
- if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
541
- inputs.input_ids = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
542
- if 'attention_mask' in inputs:
543
- inputs.attention_mask = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
544
-
545
- streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
546
- gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
547
-
548
- t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
549
- t.start()
550
-
551
- # ์ŠคํŠธ๋ฆฌ๋ฐ ์ถœ๋ ฅ
552
- output_so_far = ""
553
- for new_text in streamer:
554
- output_so_far += new_text
555
- yield output_so_far
556
-
557
- # 5) ์ด๋ฏธ์ง€ ์ƒ์„ฑ (Base64)
558
- if image_gen:
559
- last_user_text = message["text"].strip()
560
- if not last_user_text:
561
- yield output_so_far + "\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: Empty user prompt)"
562
- else:
563
- try:
564
- width, height = 512, 512
565
- guidance, steps, seed = 7.5, 30, 42
566
-
567
- logger.info(f"Generating image with prompt: {last_user_text}")
568
-
569
- # API ํ˜ธ์ถœํ•ด์„œ (base64) ์ด๋ฏธ์ง€ ์ƒ์„ฑ
570
- image_result, seed_info = generate_image(
571
- prompt=last_user_text,
572
- width=width,
573
- height=height,
574
- guidance=guidance,
575
- inference_steps=steps,
576
- seed=seed
577
- )
578
-
579
- logger.info(f"Received image data type: {type(image_result)}")
580
-
581
- # Base64 or data:image/... ์ฒ˜๋ฆฌ
582
- if image_result:
583
- if isinstance(image_result, str):
584
- # ์ด๋ฏธ data:image/๋กœ ์‹œ์ž‘ํ•˜๋ฉด ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ
585
- if image_result.startswith("data:image/"):
586
- final_md = f"\n\n**[์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€]**\n\n![์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€]({image_result})"
587
- yield output_so_far + final_md
588
- else:
589
- # ์ˆœ์ˆ˜ base64๋กœ ํŒ๋‹จ(๋‹จ, ์ผ๋ฐ˜ URL์ด๋‚˜ '/tmp/...'์ด๋ฉด ์ฒ˜๋ฆฌ ๋ถˆ๊ฐ€)
590
- if len(image_result) > 100 and "/" not in image_result:
591
- # base64
592
- image_data = "data:image/webp;base64," + image_result
593
- final_md = f"\n\n**[์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€]**\n\n![์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€]({image_data})"
594
- yield output_so_far + final_md
595
- else:
596
- # ๊ทธ ์™ธ (ex. http://..., /tmp/...) -> 403 ๋ฌธ์ œ ๋ฐœ์ƒํ•˜๋ฏ€๋กœ ํ‘œ์‹œ ์•ˆ ํ•จ
597
- yield output_so_far + "\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ base64 ํ˜•์‹์ด ์•„๋‹™๋‹ˆ๋‹ค)"
598
- else:
599
- yield output_so_far + "\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ ๋ฌธ์ž์—ด์ด ์•„๋‹˜)"
600
- else:
601
- yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: {seed_info})"
602
-
603
- except Exception as e:
604
- logger.error(f"Image generation error: {e}")
605
- yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e})"
606
-
607
- except Exception as e:
608
- logger.error(f"Error in run: {str(e)}")
609
- yield f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
610
- finally:
611
- for tmp in temp_files:
612
- try:
613
- if os.path.exists(tmp):
614
- os.unlink(tmp)
615
- logger.info(f"Deleted temp file: {tmp}")
616
- except Exception as ee:
617
- logger.warning(f"Failed to delete temp file {tmp}: {ee}")
618
- try:
619
- del inputs, streamer
620
- except Exception:
621
- pass
622
- clear_cuda_cache()
623
-
624
- # =============================================================================
625
- # ์˜ˆ์‹œ๋“ค
626
- # =============================================================================
627
- examples = [
628
- [
629
- {
630
- "text": "Compare the contents of the two PDF files.",
631
- "files": [
632
- "assets/additional-examples/before.pdf",
633
- "assets/additional-examples/after.pdf",
634
- ],
635
- }
636
- ],
637
- [
638
- {
639
- "text": "Summarize and analyze the contents of the CSV file.",
640
- "files": ["assets/additional-examples/sample-csv.csv"],
641
- }
642
- ],
643
- # ... ๋‚˜๋จธ์ง€ ์˜ˆ์‹œ ํ•„์š”ํ•˜๋‹ค๋ฉด ์ถ”๊ฐ€ ...
644
- ]
645
-
646
- # =============================================================================
647
- # Gradio UI (Blocks) ๊ตฌ์„ฑ
648
- # =============================================================================
649
-
650
- css = """
651
- .gradio-container {
652
- background: rgba(255, 255, 255, 0.7);
653
- padding: 30px 40px;
654
- margin: 20px auto;
655
- width: 100% !important;
656
- max-width: none !important;
657
- }
658
- """
659
- title_html = """
660
- <h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> ๐Ÿ’˜ HeartSync : Love Dating AI ๐Ÿ’˜ </h1>
661
- <p align="center" style="font-size:1.1em; color:#555;">
662
- โœ… FLUX Image Generation โœ… Reasoning & Uncensored โœ… Multimodal & VLM โœ… Deep-Research & RAG <br>
663
- </p>
664
- """
665
-
666
- with gr.Blocks(css=css, title="HeartSync") as demo:
667
- gr.Markdown(title_html)
668
-
669
- # ๋ณ„๋„ ๊ฐค๋Ÿฌ๋ฆฌ ์˜ˆ์‹œ (ํ•„์š” ์‹œ ์‚ฌ์šฉ)
670
- generated_images = gr.Gallery(
671
- label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€",
672
- show_label=True,
673
- visible=False,
674
- elem_id="generated_images",
675
- columns=2,
676
- height="auto",
677
- object_fit="contain"
678
- )
679
-
680
- with gr.Row():
681
- web_search_checkbox = gr.Checkbox(label="Deep Research", value=False)
682
- image_gen_checkbox = gr.Checkbox(label="Image Gen", value=False)
683
-
684
- base_system_prompt_box = gr.Textbox(
685
- lines=3,
686
- value="You are a deep thinking AI...\nํŽ˜๋ฅด์†Œ๋‚˜: ๋‹น์‹ ์€ ๋‹ฌ์ฝคํ•˜๊ณ ...",
687
- label="๊ธฐ๋ณธ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ",
688
- visible=False
689
- )
690
- with gr.Row():
691
- age_group_dropdown = gr.Dropdown(
692
- label="์—ฐ๋ น๋Œ€ ์„ ํƒ (๊ธฐ๋ณธ 20๋Œ€)",
693
- choices=["10๋Œ€", "20๋Œ€", "30~40๋Œ€", "50~60๋Œ€", "70๋Œ€ ์ด์ƒ"],
694
- value="20๋Œ€",
695
- interactive=True
696
- )
697
- mbti_choices = [
698
- "INTJ (์šฉ์˜์ฃผ๋„ํ•œ ์ „๋žต๊ฐ€)",
699
- "INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
700
- "ENTJ (๋Œ€๋‹ดํ•œ ํ†ต์†”์ž)",
701
- "ENTP (๋œจ๊ฑฐ์šด ๋…ผ์Ÿ๊ฐ€)",
702
- "INFJ (์„ ์˜์˜ ์˜นํ˜ธ์ž)",
703
- "INFP (์—ด์ •์ ์ธ ์ค‘์žฌ์ž)",
704
- "ENFJ (์ •์˜๋กœ์šด ์‚ฌํšŒ์šด๋™๊ฐ€)",
705
- "ENFP (์žฌ๊ธฐ๋ฐœ๋ž„ํ•œ ํ™œ๋™๊ฐ€)",
706
- "ISTJ (์ฒญ๋ ด๊ฒฐ๋ฐฑํ•œ ๋…ผ๋ฆฌ์ฃผ์˜์ž)",
707
- "ISFJ (์šฉ๊ฐํ•œ ์ˆ˜ํ˜ธ์ž)",
708
- "ESTJ (์—„๊ฒฉํ•œ ๊ด€๋ฆฌ์ž)",
709
- "ESFJ (์‚ฌ๊ต์ ์ธ ์™ธ๊ต๊ด€)",
710
- "ISTP (๋งŒ๋Šฅ ์žฌ์ฃผ๊พผ)",
711
- "ISFP (ํ˜ธ๊ธฐ์‹ฌ ๋งŽ์€ ์˜ˆ์ˆ ๊ฐ€)",
712
- "ESTP (๋ชจํ—˜์„ ์ฆ๊ธฐ๋Š” ์‚ฌ์—…๊ฐ€)",
713
- "ESFP (์ž์œ ๋กœ์šด ์˜ํ˜ผ์˜ ์—ฐ์˜ˆ์ธ)"
714
- ]
715
- mbti_dropdown = gr.Dropdown(
716
- label="AI ํŽ˜๋ฅด์†Œ๋‚˜ MBTI (๊ธฐ๋ณธ INTP)",
717
- choices=mbti_choices,
718
- value="INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
719
- interactive=True
720
- )
721
- sexual_openness_slider = gr.Slider(
722
- minimum=1, maximum=5, step=1, value=2,
723
- label="์„น์Šˆ์–ผ ๊ด€์‹ฌ๋„/๊ฐœ๋ฐฉ์„ฑ (1~5, ๊ธฐ๋ณธ=2)",
724
- interactive=True
725
- )
726
- max_tokens_slider = gr.Slider(
727
- label="Max New Tokens",
728
- minimum=100, maximum=8000, step=50, value=1000,
729
- visible=False
730
- )
731
- web_search_text = gr.Textbox(
732
- lines=1,
733
- label="(Unused) Web Search Query",
734
- placeholder="No direct input needed",
735
- visible=False
736
- )
737
-
738
- def modified_run(
739
- message, history, system_prompt, max_new_tokens,
740
- use_web_search, web_search_query,
741
- age_group, mbti_personality, sexual_openness, image_gen
742
- ):
743
- """
744
- run() ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ํ…์ŠคํŠธ ์ŠคํŠธ๋ฆผ์„ ๋ฐ›๊ณ ,
745
- ํ•„์š” ์‹œ ์ถ”๊ฐ€ ์ฒ˜๋ฆฌ ํ›„ ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜ (๊ฐค๋Ÿฌ๋ฆฌ ์—…๋ฐ์ดํŠธ ๋“ฑ).
746
- """
747
- output_so_far = ""
748
- gallery_update = gr.Gallery(visible=False, value=[])
749
- yield output_so_far, gallery_update
750
-
751
- text_generator = run(
752
- message, history,
753
- system_prompt, max_new_tokens,
754
- use_web_search, web_search_query,
755
- age_group, mbti_personality,
756
- sexual_openness, image_gen
757
- )
758
-
759
- for text_chunk in text_generator:
760
- output_so_far = text_chunk
761
- yield output_so_far, gallery_update
762
-
763
- # ๋งŒ์•ฝ run() ๋‚ด๋ถ€์—์„œ Base64 ์ด๋ฏธ์ง€๋ฅผ ์ด๋ฏธ ๋Œ€ํ™”์ฐฝ์— ์‚ฝ์ž…ํ–ˆ๋‹ค๋ฉด,
764
- # ์—ฌ๊ธฐ์„œ ๊ฐค๋Ÿฌ๋ฆฌ์— ๋”ฐ๋กœ ํ‘œ์‹œํ•  ํ•„์š”๋Š” ์—†์„ ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.
765
- # run() ๋‚ด๋ถ€์—์„œ์˜ image_result๋ฅผ ๊ฐ€์ ธ์˜ค๋ ค๋ฉด, run() ํ•จ์ˆ˜๊ฐ€ ํ•ด๋‹น ์ •๋ณด๋ฅผ ๋ฐ˜ํ™˜ํ•˜๋„๋ก ์ถ”๊ฐ€ ์ˆ˜์ •์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
766
-
767
- chat = gr.ChatInterface(
768
- fn=modified_run,
769
- type="messages",
770
- chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
771
- textbox=gr.MultimodalTextbox(
772
- file_types=[".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf"],
773
- file_count="multiple",
774
- autofocus=True
775
- ),
776
- multimodal=True,
777
- additional_inputs=[
778
- base_system_prompt_box,
779
- max_tokens_slider,
780
- web_search_checkbox,
781
- web_search_text,
782
- age_group_dropdown,
783
- mbti_dropdown,
784
- sexual_openness_slider,
785
- image_gen_checkbox,
786
- ],
787
- additional_outputs=[generated_images],
788
- stop_btn=False,
789
- title='<a href="https://discord.gg/openfreeai" target="_blank">https://discord.gg/openfreeai</a>',
790
- examples=examples,
791
- run_examples_on_click=False,
792
- cache_examples=False,
793
- css_paths=None,
794
- delete_cache=(1800, 1800),
795
- )
796
-
797
- with gr.Row(elem_id="examples_row"):
798
- with gr.Column(scale=12, elem_id="examples_container"):
799
- gr.Markdown("### Example Inputs (click to load)")
800
-
801
- if __name__ == "__main__":
802
- demo.launch(share=True)