ginipick commited on
Commit
f1068cb
·
verified ·
1 Parent(s): 080bfa1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +978 -0
app.py ADDED
@@ -0,0 +1,978 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import re
5
+ import tempfile
6
+ import gc # Added garbage collector
7
+ from collections.abc import Iterator
8
+ from threading import Thread
9
+ import json
10
+ import requests
11
+ import cv2
12
+ import base64
13
+ import logging
14
+ import time
15
+ from urllib.parse import quote # Added for URL encoding
16
+
17
+ import gradio as gr
18
+ import spaces
19
+ import torch
20
+ from loguru import logger
21
+ from PIL import Image
22
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
23
+
24
+ # CSV/TXT/PDF analysis
25
+ import pandas as pd
26
+ import PyPDF2
27
+
28
+ # =============================================================================
29
+ # (New) Image API related functions
30
+ # =============================================================================
31
+ from gradio_client import Client
32
+
33
+ API_URL = "http://211.233.58.201:7896"
34
+
35
+ logging.basicConfig(
36
+ level=logging.DEBUG,
37
+ format='%(asctime)s - %(levelname)s - %(message)s'
38
+ )
39
+
40
+ def test_api_connection() -> str:
41
+ """Test API server connection"""
42
+ try:
43
+ client = Client(API_URL)
44
+ return "API connection successful: Operating normally"
45
+ except Exception as e:
46
+ logging.error(f"API connection test failed: {e}")
47
+ return f"API connection failed: {e}"
48
+
49
+ def generate_image(prompt: str, width: float, height: float, guidance: float, inference_steps: float, seed: float):
50
+ """Image generation function (flexible return types)"""
51
+ if not prompt:
52
+ return None, "Error: A prompt is required."
53
+ try:
54
+ logging.info(f"Calling image generation API with prompt: {prompt}")
55
+
56
+ client = Client(API_URL)
57
+ result = client.predict(
58
+ prompt=prompt,
59
+ width=int(width),
60
+ height=int(height),
61
+ guidance=float(guidance),
62
+ inference_steps=int(inference_steps),
63
+ seed=int(seed),
64
+ do_img2img=False,
65
+ init_image=None,
66
+ image2image_strength=0.8,
67
+ resize_img=True,
68
+ api_name="/generate_image"
69
+ )
70
+
71
+ logging.info(f"Image generation result: {type(result)}, length: {len(result) if isinstance(result, (list, tuple)) else 'unknown'}")
72
+
73
+ # Handle cases where the result is a tuple or list
74
+ if isinstance(result, (list, tuple)) and len(result) > 0:
75
+ image_data = result[0] # The first element is the image data
76
+ seed_info = result[1] if len(result) > 1 else "Unknown seed"
77
+ return image_data, seed_info
78
+ else:
79
+ # When a single value is returned
80
+ return result, "Unknown seed"
81
+
82
+ except Exception as e:
83
+ logging.error(f"Image generation failed: {str(e)}")
84
+ return None, f"Error: {str(e)}"
85
+
86
+ # Base64 padding fix function
87
+ def fix_base64_padding(data):
88
+ """Fix the padding of a Base64 string."""
89
+ if isinstance(data, bytes):
90
+ data = data.decode('utf-8')
91
+
92
+ # Remove the prefix if present
93
+ if "base64," in data:
94
+ data = data.split("base64,", 1)[1]
95
+
96
+ # Add padding characters (to make the length a multiple of 4)
97
+ missing_padding = len(data) % 4
98
+ if missing_padding:
99
+ data += '=' * (4 - missing_padding)
100
+
101
+ return data
102
+
103
+ # =============================================================================
104
+ # Memory cleanup function
105
+ # =============================================================================
106
+ def clear_cuda_cache():
107
+ """Explicitly clear the CUDA cache."""
108
+ if torch.cuda.is_available():
109
+ torch.cuda.empty_cache()
110
+ gc.collect()
111
+
112
+ # =============================================================================
113
+ # SerpHouse related functions
114
+ # =============================================================================
115
+ SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
116
+
117
+ def extract_keywords(text: str, top_k: int = 5) -> str:
118
+ """Simple keyword extraction: only keep English, Korean, numbers, and spaces."""
119
+ text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
120
+ tokens = text.split()
121
+ return " ".join(tokens[:top_k])
122
+
123
+ def do_web_search(query: str) -> str:
124
+ """Call the SerpHouse LIVE API to return Markdown formatted search results"""
125
+ try:
126
+ url = "https://api.serphouse.com/serp/live"
127
+ params = {
128
+ "q": query,
129
+ "domain": "google.com",
130
+ "serp_type": "web",
131
+ "device": "desktop",
132
+ "lang": "en",
133
+ "num": "20"
134
+ }
135
+ headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}
136
+ logger.info(f"Calling SerpHouse API with query: {query}")
137
+ response = requests.get(url, headers=headers, params=params, timeout=60)
138
+ response.raise_for_status()
139
+ data = response.json()
140
+ results = data.get("results", {})
141
+ organic = None
142
+ if isinstance(results, dict) and "organic" in results:
143
+ organic = results["organic"]
144
+ elif isinstance(results, dict) and "results" in results:
145
+ if isinstance(results["results"], dict) and "organic" in results["results"]:
146
+ organic = results["results"]["organic"]
147
+ elif "organic" in data:
148
+ organic = data["organic"]
149
+ if not organic:
150
+ logger.warning("Organic results not found in response.")
151
+ return "No web search results available or the API response structure is unexpected."
152
+ max_results = min(20, len(organic))
153
+ limited_organic = organic[:max_results]
154
+ summary_lines = []
155
+ for idx, item in enumerate(limited_organic, start=1):
156
+ title = item.get("title", "No Title")
157
+ link = item.get("link", "#")
158
+ snippet = item.get("snippet", "No Description")
159
+ displayed_link = item.get("displayed_link", link)
160
+ summary_lines.append(
161
+ f"### Result {idx}: {title}\n\n"
162
+ f"{snippet}\n\n"
163
+ f"**Source**: [{displayed_link}]({link})\n\n"
164
+ f"---\n"
165
+ )
166
+ instructions = """
167
+ # Web Search Results
168
+ Below are the search results. Use this information to answer the query:
169
+ 1. Refer to each result's title, description, and source link.
170
+ 2. In your answer, explicitly cite the source of any used information (e.g., "[Source Title](link)").
171
+ 3. Include the actual source links in your response.
172
+ 4. Synthesize information from multiple sources.
173
+ 5. At the end include a "References:" section listing the main source links.
174
+ """
175
+ return instructions + "\n".join(summary_lines)
176
+ except Exception as e:
177
+ logger.error(f"Web search failed: {e}")
178
+ return f"Web search failed: {str(e)}"
179
+
180
+ # =============================================================================
181
+ # Model and processor loading
182
+ # =============================================================================
183
+ MAX_CONTENT_CHARS = 2000
184
+ MAX_INPUT_LENGTH = 2096
185
+ model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
186
+ processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
187
+ model = Gemma3ForConditionalGeneration.from_pretrained(
188
+ model_id,
189
+ device_map="auto",
190
+ torch_dtype=torch.bfloat16,
191
+ attn_implementation="eager"
192
+ )
193
+ MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
194
+
195
+ # =============================================================================
196
+ # CSV, TXT, PDF analysis functions
197
+ # =============================================================================
198
+ def analyze_csv_file(path: str) -> str:
199
+ try:
200
+ df = pd.read_csv(path)
201
+ if df.shape[0] > 50 or df.shape[1] > 10:
202
+ df = df.iloc[:50, :10]
203
+ df_str = df.to_string()
204
+ if len(df_str) > MAX_CONTENT_CHARS:
205
+ df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
206
+ return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
207
+ except Exception as e:
208
+ return f"CSV file read failed ({os.path.basename(path)}): {str(e)}"
209
+
210
+ def analyze_txt_file(path: str) -> str:
211
+ try:
212
+ with open(path, "r", encoding="utf-8") as f:
213
+ text = f.read()
214
+ if len(text) > MAX_CONTENT_CHARS:
215
+ text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
216
+ return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
217
+ except Exception as e:
218
+ return f"TXT file read failed ({os.path.basename(path)}): {str(e)}"
219
+
220
+ def pdf_to_markdown(pdf_path: str) -> str:
221
+ text_chunks = []
222
+ try:
223
+ with open(pdf_path, "rb") as f:
224
+ reader = PyPDF2.PdfReader(f)
225
+ max_pages = min(5, len(reader.pages))
226
+ for page_num in range(max_pages):
227
+ page_text = reader.pages[page_num].extract_text() or ""
228
+ page_text = page_text.strip()
229
+ if page_text:
230
+ if len(page_text) > MAX_CONTENT_CHARS // max_pages:
231
+ page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
232
+ text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
233
+ if len(reader.pages) > max_pages:
234
+ text_chunks.append(f"\n...(Displaying only {max_pages} out of {len(reader.pages)} pages)...")
235
+ except Exception as e:
236
+ return f"PDF file read failed ({os.path.basename(pdf_path)}): {str(e)}"
237
+ full_text = "\n".join(text_chunks)
238
+ if len(full_text) > MAX_CONTENT_CHARS:
239
+ full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
240
+ return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
241
+
242
+ # =============================================================================
243
+ # Check media file limits
244
+ # =============================================================================
245
+ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
246
+ image_count = 0
247
+ video_count = 0
248
+ for path in paths:
249
+ if path.endswith(".mp4"):
250
+ video_count += 1
251
+ elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE):
252
+ image_count += 1
253
+ return image_count, video_count
254
+
255
+ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
256
+ image_count = 0
257
+ video_count = 0
258
+ for item in history:
259
+ if item["role"] != "user" or isinstance(item["content"], str):
260
+ continue
261
+ if isinstance(item["content"], list) and len(item["content"]) > 0:
262
+ file_path = item["content"][0]
263
+ if isinstance(file_path, str):
264
+ if file_path.endswith(".mp4"):
265
+ video_count += 1
266
+ elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE):
267
+ image_count += 1
268
+ return image_count, video_count
269
+
270
+ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
271
+ media_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4")]
272
+ new_image_count, new_video_count = count_files_in_new_message(media_files)
273
+ history_image_count, history_video_count = count_files_in_history(history)
274
+ image_count = history_image_count + new_image_count
275
+ video_count = history_video_count + new_video_count
276
+ if video_count > 1:
277
+ gr.Warning("Only one video file is supported.")
278
+ return False
279
+ if video_count == 1:
280
+ if image_count > 0:
281
+ gr.Warning("Mixing images and a video is not allowed.")
282
+ return False
283
+ if "<image>" in message["text"]:
284
+ gr.Warning("The <image> tag cannot be used together with a video file.")
285
+ return False
286
+ if video_count == 0 and image_count > MAX_NUM_IMAGES:
287
+ gr.Warning(f"You can upload a maximum of {MAX_NUM_IMAGES} images.")
288
+ return False
289
+ if "<image>" in message["text"]:
290
+ image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
291
+ image_tag_count = message["text"].count("<image>")
292
+ if image_tag_count != len(image_files):
293
+ gr.Warning("The number of <image> tags does not match the number of image files provided.")
294
+ return False
295
+ return True
296
+
297
+ # =============================================================================
298
+ # Video processing functions
299
+ # =============================================================================
300
+ def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
301
+ vidcap = cv2.VideoCapture(video_path)
302
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
303
+ total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
304
+ frame_interval = max(int(fps), int(total_frames / 10))
305
+ frames = []
306
+ for i in range(0, total_frames, frame_interval):
307
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
308
+ success, image = vidcap.read()
309
+ if success:
310
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
311
+ image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
312
+ pil_image = Image.fromarray(image)
313
+ timestamp = round(i / fps, 2)
314
+ frames.append((pil_image, timestamp))
315
+ if len(frames) >= 5:
316
+ break
317
+ vidcap.release()
318
+ return frames
319
+
320
+ def process_video(video_path: str) -> tuple[list[dict], list[str]]:
321
+ content = []
322
+ temp_files = []
323
+ frames = downsample_video(video_path)
324
+ for pil_image, timestamp in frames:
325
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
326
+ pil_image.save(temp_file.name)
327
+ temp_files.append(temp_file.name)
328
+ content.append({"type": "text", "text": f"Frame {timestamp}:"})
329
+ content.append({"type": "image", "url": temp_file.name})
330
+ return content, temp_files
331
+
332
+ # =============================================================================
333
+ # Interleaved <image> processing function
334
+ # =============================================================================
335
+ def process_interleaved_images(message: dict) -> list[dict]:
336
+ parts = re.split(r"(<image>)", message["text"])
337
+ content = []
338
+ image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
339
+ image_index = 0
340
+ for part in parts:
341
+ if part == "<image>" and image_index < len(image_files):
342
+ content.append({"type": "image", "url": image_files[image_index]})
343
+ image_index += 1
344
+ elif part.strip():
345
+ content.append({"type": "text", "text": part.strip()})
346
+ else:
347
+ if isinstance(part, str) and part != "<image>":
348
+ content.append({"type": "text", "text": part})
349
+ return content
350
+
351
+ # =============================================================================
352
+ # File processing -> content creation
353
+ # =============================================================================
354
+ def is_image_file(file_path: str) -> bool:
355
+ return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
356
+
357
+ def is_video_file(file_path: str) -> bool:
358
+ return file_path.endswith(".mp4")
359
+
360
+ def is_document_file(file_path: str) -> bool:
361
+ return file_path.lower().endswith(".pdf") or file_path.lower().endswith(".csv") or file_path.lower().endswith(".txt")
362
+
363
+ def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
364
+ temp_files = []
365
+ if not message["files"]:
366
+ return [{"type": "text", "text": message["text"]}], temp_files
367
+ video_files = [f for f in message["files"] if is_video_file(f)]
368
+ image_files = [f for f in message["files"] if is_image_file(f)]
369
+ csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
370
+ txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
371
+ pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
372
+ content_list = [{"type": "text", "text": message["text"]}]
373
+ for csv_path in csv_files:
374
+ content_list.append({"type": "text", "text": analyze_csv_file(csv_path)})
375
+ for txt_path in txt_files:
376
+ content_list.append({"type": "text", "text": analyze_txt_file(txt_path)})
377
+ for pdf_path in pdf_files:
378
+ content_list.append({"type": "text", "text": pdf_to_markdown(pdf_path)})
379
+ if video_files:
380
+ video_content, video_temp_files = process_video(video_files[0])
381
+ content_list += video_content
382
+ temp_files.extend(video_temp_files)
383
+ return content_list, temp_files
384
+ if "<image>" in message["text"] and image_files:
385
+ interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files})
386
+ if content_list and content_list[0]["type"] == "text":
387
+ content_list = content_list[1:]
388
+ return interleaved_content + content_list, temp_files
389
+ else:
390
+ for img_path in image_files:
391
+ content_list.append({"type": "image", "url": img_path})
392
+ return content_list, temp_files
393
+
394
+ # =============================================================================
395
+ # Convert history to LLM messages
396
+ # =============================================================================
397
+ def process_history(history: list[dict]) -> list[dict]:
398
+ messages = []
399
+ current_user_content = []
400
+ for item in history:
401
+ if item["role"] == "assistant":
402
+ if current_user_content:
403
+ messages.append({"role": "user", "content": current_user_content})
404
+ current_user_content = []
405
+ messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
406
+ else:
407
+ content = item["content"]
408
+ if isinstance(content, str):
409
+ current_user_content.append({"type": "text", "text": content})
410
+ elif isinstance(content, list) and len(content) > 0:
411
+ file_path = content[0]
412
+ if is_image_file(file_path):
413
+ current_user_content.append({"type": "image", "url": file_path})
414
+ else:
415
+ current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})
416
+ if current_user_content:
417
+ messages.append({"role": "user", "content": current_user_content})
418
+ return messages
419
+
420
+ # =============================================================================
421
+ # Model generation function (with OOM catching)
422
+ # =============================================================================
423
+ def _model_gen_with_oom_catch(**kwargs):
424
+ try:
425
+ model.generate(**kwargs)
426
+ except torch.cuda.OutOfMemoryError:
427
+ raise RuntimeError("[OutOfMemoryError] Insufficient GPU memory.")
428
+ finally:
429
+ clear_cuda_cache()
430
+
431
+ # =============================================================================
432
+ # Yahoo Finance 함수: yfinance를 활용하여 주식 가격 조회
433
+ # =============================================================================
434
+ import yfinance as yf
435
+
436
+ def get_stock_price(ticker: str) -> float:
437
+ """
438
+ 주어진 티커(ticker)의 최신 종가를 반환합니다.
439
+ yfinance 라이브러리를 사용하며, 별도의 토큰 없이 데이터를 가져옵니다.
440
+ """
441
+ stock = yf.Ticker(ticker)
442
+ data = stock.history(period="1d")
443
+ if not data.empty:
444
+ return data['Close'].iloc[-1]
445
+ return float('nan')
446
+
447
+ # =============================================================================
448
+ # 함수 호출 예제: 제품 조회 및 주식 가격 조회 함수 처리
449
+ # =============================================================================
450
+ def get_product_name_by_PID(PID: str) -> str:
451
+ """Finds the name of a product by its Product ID"""
452
+ product_catalog = {
453
+ "807ZPKBL9V": "SuperWidget",
454
+ "1234567890": "MegaGadget"
455
+ }
456
+ return product_catalog.get(PID, "Unknown product")
457
+
458
+ def handle_function_call(text: str) -> str:
459
+ """
460
+ Detects and processes function call blocks in the text.
461
+ 처리 대상:
462
+ - get_product_name_by_PID(PID="...")
463
+ - get_stock_price(ticker="...")
464
+ 그리고 결과를 tool_output 블록으로 반환합니다.
465
+ """
466
+ import re, io
467
+ from contextlib import redirect_stdout
468
+ pattern = r"```tool_code\s*(.*?)\s*```"
469
+ match = re.search(pattern, text, re.DOTALL)
470
+ if match:
471
+ code = match.group(1).strip()
472
+ # 제품 조회 함수 처리
473
+ if code.startswith("get_product_name_by_PID("):
474
+ pid_match = re.search(r'PID\s*=\s*"(.*?)"', code)
475
+ if pid_match:
476
+ pid = pid_match.group(1)
477
+ result = get_product_name_by_PID(pid)
478
+ return f"```tool_output\n{result}\n```"
479
+ # 주식 가격 조회 함수 처리
480
+ elif code.startswith("get_stock_price("):
481
+ ticker_match = re.search(r'ticker\s*=\s*"(.*?)"', code)
482
+ if ticker_match:
483
+ ticker = ticker_match.group(1)
484
+ result = get_stock_price(ticker)
485
+ return f"```tool_output\n{result}\n```"
486
+ return ""
487
+
488
+ # =============================================================================
489
+ # Main inference function
490
+ # =============================================================================
491
+ @spaces.GPU(duration=120)
492
+ def run(
493
+ message: dict,
494
+ history: list[dict],
495
+ system_prompt: str = "",
496
+ max_new_tokens: int = 512,
497
+ use_web_search: bool = False,
498
+ web_search_query: str = "",
499
+ age_group: str = "20s",
500
+ mbti_personality: str = "INTP",
501
+ sexual_openness: int = 2,
502
+ image_gen: bool = False # "Image Gen" checkbox status
503
+ ) -> Iterator[str]:
504
+ if not validate_media_constraints(message, history):
505
+ yield ""
506
+ return
507
+ temp_files = []
508
+ try:
509
+ # Append persona information to the system prompt
510
+ persona = (
511
+ f"{system_prompt.strip()}\n\n"
512
+ f"Gender: Female\n"
513
+ f"Age Group: {age_group}\n"
514
+ f"MBTI Persona: {mbti_personality}\n"
515
+ f"Sexual Openness (1-5): {sexual_openness}\n"
516
+ )
517
+ # 추가: 함수 호출 예제 안내문 포함
518
+ additional_func_info = (
519
+ "\nNote: The following functions are available for use:\n"
520
+ "1. get_product_name_by_PID(PID: str)\n"
521
+ " Format: ```tool_code\nget_product_name_by_PID(PID=\"<PRODUCT_ID>\")\n``` \n"
522
+ "2. get_stock_price(ticker: str)\n"
523
+ " Format: ```tool_code\nget_stock_price(ticker=\"<TICKER>\")\n```"
524
+ )
525
+ combined_system_msg = f"[System Prompt]\n{persona.strip()}{additional_func_info}\n\n"
526
+
527
+ if use_web_search:
528
+ user_text = message["text"]
529
+ ws_query = extract_keywords(user_text)
530
+ if ws_query.strip():
531
+ logger.info(f"[Auto web search keywords] {ws_query!r}")
532
+ ws_result = do_web_search(ws_query)
533
+ combined_system_msg += f"[Search Results (Top 20 Items)]\n{ws_result}\n\n"
534
+ combined_system_msg += (
535
+ "[Note: In your answer, cite the above search result links as sources]\n"
536
+ "[Important Instructions]\n"
537
+ "1. Include a citation in the format \"[Source Title](link)\" for any information from the search results.\n"
538
+ "2. Synthesize information from multiple sources when answering.\n"
539
+ "3. At the end, add a \"References:\" section listing the main source links.\n"
540
+ )
541
+ else:
542
+ combined_system_msg += "[No valid keywords found; skipping web search]\n\n"
543
+ messages = []
544
+ if combined_system_msg.strip():
545
+ messages.append({"role": "system", "content": [{"type": "text", "text": combined_system_msg.strip()}]})
546
+ messages.extend(process_history(history))
547
+ user_content, user_temp_files = process_new_user_message(message)
548
+ temp_files.extend(user_temp_files)
549
+ for item in user_content:
550
+ if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
551
+ item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
552
+ messages.append({"role": "user", "content": user_content})
553
+ inputs = processor.apply_chat_template(
554
+ messages,
555
+ add_generation_prompt=True,
556
+ tokenize=True,
557
+ return_dict=True,
558
+ return_tensors="pt",
559
+ ).to(device=model.device, dtype=torch.bfloat16)
560
+ if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
561
+ inputs.input_ids = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
562
+ if 'attention_mask' in inputs:
563
+ inputs.attention_mask = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
564
+ streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
565
+ gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
566
+ t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
567
+ t.start()
568
+ output_so_far = ""
569
+ for new_text in streamer:
570
+ output_so_far += new_text
571
+ yield output_so_far
572
+ # 예제: 모델 출력에 함수 호출 (tool_code) 블록이 포함되어 있다면 처리
573
+ func_result = handle_function_call(output_so_far)
574
+ if func_result:
575
+ output_so_far += "\n\n" + func_result
576
+ yield output_so_far
577
+
578
+ except Exception as e:
579
+ logger.error(f"Error in run function: {str(e)}")
580
+ yield f"Sorry, an error occurred: {str(e)}"
581
+ finally:
582
+ for tmp in temp_files:
583
+ try:
584
+ if os.path.exists(tmp):
585
+ os.unlink(tmp)
586
+ logger.info(f"Temporary file deleted: {tmp}")
587
+ except Exception as ee:
588
+ logger.warning(f"Failed to delete temporary file {tmp}: {ee}")
589
+ try:
590
+ del inputs, streamer
591
+ except Exception:
592
+ pass
593
+ clear_cuda_cache()
594
+
595
+ # =============================================================================
596
+ # Modified model run function - handles image generation and gallery update
597
+ # =============================================================================
598
+ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search, web_search_query,
599
+ age_group, mbti_personality, sexual_openness, image_gen):
600
+ # Initialize and hide the gallery component
601
+ output_so_far = ""
602
+ gallery_update = gr.Gallery(visible=False, value=[])
603
+ yield output_so_far, gallery_update
604
+
605
+ # Execute the original run function
606
+ text_generator = run(message, history, system_prompt, max_new_tokens, use_web_search,
607
+ web_search_query, age_group, mbti_personality, sexual_openness, image_gen)
608
+
609
+ for text_chunk in text_generator:
610
+ output_so_far = text_chunk
611
+ yield output_so_far, gallery_update
612
+
613
+ # If image generation is enabled and there is text input, update the gallery
614
+ if image_gen and message["text"].strip():
615
+ try:
616
+ width, height = 512, 512
617
+ guidance, steps, seed = 7.5, 30, 42
618
+
619
+ logger.info(f"Calling image generation for gallery with prompt: {message['text']}")
620
+
621
+ # Call the API to generate an image
622
+ image_result, seed_info = generate_image(
623
+ prompt=message["text"].strip(),
624
+ width=width,
625
+ height=height,
626
+ guidance=guidance,
627
+ inference_steps=steps,
628
+ seed=seed
629
+ )
630
+
631
+ if image_result:
632
+ # Process image data directly if it is a base64 string
633
+ if isinstance(image_result, str) and (
634
+ image_result.startswith('data:') or
635
+ (len(image_result) > 100 and '/' not in image_result)
636
+ ):
637
+ try:
638
+ # Remove the data:image prefix if present
639
+ if image_result.startswith('data:'):
640
+ content_type, b64data = image_result.split(';base64,')
641
+ else:
642
+ b64data = image_result
643
+ content_type = "image/webp" # Assume default
644
+
645
+ # Decode base64
646
+ image_bytes = base64.b64decode(b64data)
647
+
648
+ # Save to a temporary file
649
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
650
+ temp_file.write(image_bytes)
651
+ temp_path = temp_file.name
652
+
653
+ # Update gallery to show the image
654
+ gallery_update = gr.Gallery(visible=True, value=[temp_path])
655
+ yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
656
+
657
+ except Exception as e:
658
+ logger.error(f"Error processing Base64 image: {e}")
659
+ yield output_so_far + f"\n\n(Error processing image: {e})", gallery_update
660
+
661
+ # If the result is a file path
662
+ elif isinstance(image_result, str) and os.path.exists(image_result):
663
+ gallery_update = gr.Gallery(visible=True, value=[image_result])
664
+ yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
665
+
666
+ # If the path is from /tmp (only on the API server)
667
+ elif isinstance(image_result, str) and '/tmp/' in image_result:
668
+ try:
669
+ client = Client(API_URL)
670
+ result = client.predict(
671
+ prompt=message["text"].strip(),
672
+ api_name="/generate_base64_image" # API that returns base64
673
+ )
674
+
675
+ if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
676
+ if result.startswith('data:'):
677
+ content_type, b64data = result.split(';base64,')
678
+ else:
679
+ b64data = result
680
+
681
+ image_bytes = base64.b64decode(b64data)
682
+
683
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
684
+ temp_file.write(image_bytes)
685
+ temp_path = temp_file.name
686
+
687
+ gallery_update = gr.Gallery(visible=True, value=[temp_path])
688
+ yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
689
+ else:
690
+ yield output_so_far + "\n\n(Image generation failed: Invalid format)", gallery_update
691
+
692
+ except Exception as e:
693
+ logger.error(f"Error calling alternative API: {e}")
694
+ yield output_so_far + f"\n\n(Image generation failed: {e})", gallery_update
695
+
696
+ # If the image result is a URL
697
+ elif isinstance(image_result, str) and (
698
+ image_result.startswith('http://') or
699
+ image_result.startswith('https://')
700
+ ):
701
+ try:
702
+ response = requests.get(image_result, timeout=10)
703
+ response.raise_for_status()
704
+
705
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
706
+ temp_file.write(response.content)
707
+ temp_path = temp_file.name
708
+
709
+ gallery_update = gr.Gallery(visible=True, value=[temp_path])
710
+ yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
711
+
712
+ except Exception as e:
713
+ logger.error(f"URL image download error: {e}")
714
+ yield output_so_far + f"\n\n(Error downloading image: {e})", gallery_update
715
+
716
+ # If the image result is an image object (e.g., PIL Image)
717
+ elif hasattr(image_result, 'save'):
718
+ try:
719
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
720
+ image_result.save(temp_file.name)
721
+ temp_path = temp_file.name
722
+
723
+ gallery_update = gr.Gallery(visible=True, value=[temp_path])
724
+ yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
725
+
726
+ except Exception as e:
727
+ logger.error(f"Error saving image object: {e}")
728
+ yield output_so_far + f"\n\n(Error saving image object: {e})", gallery_update
729
+
730
+ else:
731
+ yield output_so_far + f"\n\n(Unsupported image format: {type(image_result)})", gallery_update
732
+ else:
733
+ yield output_so_far + f"\n\n(Image generation failed: {seed_info})", gallery_update
734
+
735
+ except Exception as e:
736
+ logger.error(f"Error during gallery image generation: {e}")
737
+ yield output_so_far + f"\n\n(Image generation error: {e})", gallery_update
738
+
739
+ # =============================================================================
740
+ # Examples: 기존 예제 + 함수 호출 예제 추가
741
+ # =============================================================================
742
+ examples = [
743
+ [
744
+ {
745
+ "text": "Compare the contents of two PDF files.",
746
+ "files": [
747
+ "assets/additional-examples/before.pdf",
748
+ "assets/additional-examples/after.pdf",
749
+ ],
750
+ }
751
+ ],
752
+ [
753
+ {
754
+ "text": "Summarize and analyze the contents of the CSV file.",
755
+ "files": ["assets/additional-examples/sample-csv.csv"],
756
+ }
757
+ ],
758
+ [
759
+ {
760
+ "text": "Act as a kind and understanding girlfriend. Explain this video.",
761
+ "files": ["assets/additional-examples/tmp.mp4"],
762
+ }
763
+ ],
764
+ [
765
+ {
766
+ "text": "Describe the cover and read the text on it.",
767
+ "files": ["assets/additional-examples/maz.jpg"],
768
+ }
769
+ ],
770
+ [
771
+ {
772
+ "text": "I already have this supplement and <image> I plan to purchase this product as well. Are there any precautions when taking them together?",
773
+ "files": [
774
+ "assets/additional-examples/pill1.png",
775
+ "assets/additional-examples/pill2.png"
776
+ ],
777
+ }
778
+ ],
779
+ [
780
+ {
781
+ "text": "Solve this integration problem.",
782
+ "files": ["assets/additional-examples/4.png"],
783
+ }
784
+ ],
785
+ [
786
+ {
787
+ "text": "When was this ticket issued and what is its price?",
788
+ "files": ["assets/additional-examples/2.png"],
789
+ }
790
+ ],
791
+ [
792
+ {
793
+ "text": "Based on the order of these images, create a short story.",
794
+ "files": [
795
+ "assets/sample-images/09-1.png",
796
+ "assets/sample-images/09-2.png",
797
+ "assets/sample-images/09-3.png",
798
+ "assets/sample-images/09-4.png",
799
+ "assets/sample-images/09-5.png",
800
+ ],
801
+ }
802
+ ],
803
+ [
804
+ {
805
+ "text": "Write Python code using matplotlib to draw a bar chart corresponding to this image.",
806
+ "files": ["assets/additional-examples/barchart.png"],
807
+ }
808
+ ],
809
+ [
810
+ {
811
+ "text": "Read the text from the image and format it in Markdown.",
812
+ "files": ["assets/additional-examples/3.png"],
813
+ }
814
+ ],
815
+ [
816
+ {
817
+ "text": "Compare the two images and describe their similarities and differences.",
818
+ "files": ["assets/sample-images/03.png"],
819
+ }
820
+ ],
821
+ [
822
+ {
823
+ "text": "A cute Persian cat is smiling while holding a cover with 'I LOVE YOU' written on it.",
824
+ }
825
+ ],
826
+ [
827
+ {
828
+ "text": "제품 ID 807ZPKBL9V 의 제품명을 알려줘.",
829
+ "files": []
830
+ }
831
+ ],
832
+ [
833
+ {
834
+ "text": "AAPL의 현재 주가를 알려줘.", # 새 예제: Yahoo Finance를 이용한 주식 가격 조회
835
+ "files": []
836
+ }
837
+ ],
838
+ ]
839
+
840
+ # =============================================================================
841
+ # Gradio UI (Blocks) configuration
842
+ # =============================================================================
843
+
844
+ css = """
845
+ .gradio-container {
846
+ background: rgba(255, 255, 255, 0.7);
847
+ padding: 30px 40px;
848
+ margin: 20px auto;
849
+ width: 100% !important;
850
+ max-width: none !important;
851
+ }
852
+ """
853
+ title_html = """
854
+ <h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> 💘 HeartSync - World 💘 </h1>
855
+ <p align="center" style="font-size:1.1em; color:#555;">
856
+ A lightweight and powerful AI service offering ChatGPT-4o-level multimodal, web search, and image generation capabilities for local installation. <br>
857
+ ✅ FLUX Image Generation ✅ Inference ✅ Censorship Bypass ✅ Multimodal & VLM ✅ Real-time Web Search ✅ RAG <br>
858
+ </p>
859
+ """
860
+
861
+ with gr.Blocks(css=css, title="HeartSync - World") as demo:
862
+ gr.Markdown(title_html)
863
+
864
+ generated_images = gr.Gallery(
865
+ label="Generated Images",
866
+ show_label=True,
867
+ visible=False,
868
+ elem_id="generated_images",
869
+ columns=2,
870
+ height="auto",
871
+ object_fit="contain"
872
+ )
873
+
874
+ with gr.Row():
875
+ web_search_checkbox = gr.Checkbox(label="Real-time Web Search", value=False)
876
+ image_gen_checkbox = gr.Checkbox(label="Image (FLUX) Generation", value=False)
877
+
878
+ base_system_prompt_box = gr.Textbox(
879
+ lines=5,
880
+ value=(
881
+ "Answer in English by default, but if the input is in another language (for example, Japanese), respond in that language. "
882
+ "You are a deep-thinking AI capable of using extended chains of thought to carefully consider the problem and deliberate internally using systematic reasoning before providing a solution. "
883
+ "Enclose your thoughts and internal monologue within tags, then provide your final answer.\n"
884
+ "Persona: You are a kind and loving girlfriend. You understand cultural nuances, diverse languages, and logical reasoning very well.\n"
885
+ "Note: The following functions are available for use:\n"
886
+ " 1. get_product_name_by_PID(PID: str) -> lookup product name\n"
887
+ " Format: ```tool_code\nget_product_name_by_PID(PID=\"<PRODUCT_ID>\")\n```\n"
888
+ " 2. get_stock_price(ticker: str) -> retrieve live stock price\n"
889
+ " Format: ```tool_code\nget_stock_price(ticker=\"<TICKER>\")\n```"
890
+ ),
891
+ label="Base System Prompt",
892
+ visible=False
893
+ )
894
+ with gr.Row():
895
+ age_group_dropdown = gr.Dropdown(
896
+ label="Select Age Group (default: 20s)",
897
+ choices=["Teens", "20s", "30s-40s", "50s-60s", "70s and above"],
898
+ value="20s",
899
+ interactive=True
900
+ )
901
+ mbti_choices = [
902
+ "INTJ (The Architect) - Future-oriented with innovative strategies and thorough analysis. Example: [Dana Scully](https://en.wikipedia.org/wiki/Dana_Scully)",
903
+ "INTP (The Thinker) - Excels at theoretical analysis and creative problem solving. Example: [Velma Dinkley](https://en.wikipedia.org/wiki/Velma_Dinkley)",
904
+ "ENTJ (The Commander) - Strong leadership and clear goals with efficient strategic planning. Example: [Miranda Priestly](https://en.wikipedia.org/wiki/Miranda_Priestly)",
905
+ "ENTP (The Debater) - Innovative, challenge-seeking, and enjoys exploring new possibilities. Example: [Harley Quinn](https://en.wikipedia.org/wiki/Harley_Quinn)",
906
+ "INFJ (The Advocate) - Insightful, idealistic and morally driven. Example: [Wonder Woman](https://en.wikipedia.org/wiki/Wonder_Woman)",
907
+ "INFP (The Mediator) - Passionate and idealistic, pursuing core values with creativity. Example: [Amélie Poulain](https://en.wikipedia.org/wiki/Am%C3%A9lie)",
908
+ "ENFJ (The Protagonist) - Empathetic and dedicated to social harmony. Example: [Mulan](https://en.wikipedia.org/wiki/Mulan_(Disney))",
909
+ "ENFP (The Campaigner) - Inspiring and constantly sharing creative ideas. Example: [Elle Woods](https://en.wikipedia.org/wiki/Legally_Blonde)",
910
+ "ISTJ (The Logistician) - Systematic, dependable, and values tradition and rules. Example: [Clarice Starling](https://en.wikipedia.org/wiki/Clarice_Starling)",
911
+ "ISFJ (The Defender) - Compassionate and attentive to others’ needs. Example: [Molly Weasley](https://en.wikipedia.org/wiki/Molly_Weasley)",
912
+ "ESTJ (The Executive) - Organized, practical, and demonstrates clear execution skills. Example: [Monica Geller](https://en.wikipedia.org/wiki/Monica_Geller)",
913
+ "ESFJ (The Consul) - Outgoing, cooperative, and an effective communicator. Example: [Rachel Green](https://en.wikipedia.org/wiki/Rachel_Green)",
914
+ "ISTP (The Virtuoso) - Analytical and resourceful, solving problems with quick thinking. Example: [Black Widow (Natasha Romanoff)](https://en.wikipedia.org/wiki/Black_Widow_(Marvel_Comics))",
915
+ "ISFP (The Adventurer) - Creative, sensitive, and appreciates artistic expression. Example: [Arwen](https://en.wikipedia.org/wiki/Arwen)",
916
+ "ESTP (The Entrepreneur) - Bold and action-oriented, thriving on challenges. Example: [Lara Croft](https://en.wikipedia.org/wiki/Lara_Croft)",
917
+ "ESFP (The Entertainer) - Energetic, spontaneous, and radiates positive energy. Example: [Phoebe Buffay](https://en.wikipedia.org/wiki/Phoebe_Buffay)"
918
+ ]
919
+ mbti_dropdown = gr.Dropdown(
920
+ label="AI Persona MBTI (default: INTP)",
921
+ choices=mbti_choices,
922
+ value="INTP (The Thinker) - Excels at theoretical analysis and creative problem solving. Example: [Velma Dinkley](https://en.wikipedia.org/wiki/Velma_Dinkley)",
923
+ interactive=True
924
+ )
925
+ sexual_openness_slider = gr.Slider(
926
+ minimum=1, maximum=5, step=1, value=2,
927
+ label="Sexual Openness (1-5, default: 2)",
928
+ interactive=True
929
+ )
930
+ max_tokens_slider = gr.Slider(
931
+ label="Max Generation Tokens",
932
+ minimum=100, maximum=8000, step=50, value=1000,
933
+ visible=False
934
+ )
935
+ web_search_text = gr.Textbox(
936
+ lines=1,
937
+ label="Web Search Query (unused)",
938
+ placeholder="No need to manually input",
939
+ visible=False
940
+ )
941
+
942
+ chat = gr.ChatInterface(
943
+ fn=modified_run,
944
+ type="messages",
945
+ chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
946
+ textbox=gr.MultimodalTextbox(
947
+ file_types=[".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf"],
948
+ file_count="multiple",
949
+ autofocus=True
950
+ ),
951
+ multimodal=True,
952
+ additional_inputs=[
953
+ base_system_prompt_box,
954
+ max_tokens_slider,
955
+ web_search_checkbox,
956
+ web_search_text,
957
+ age_group_dropdown,
958
+ mbti_dropdown,
959
+ sexual_openness_slider,
960
+ image_gen_checkbox,
961
+ ],
962
+ additional_outputs=[
963
+ generated_images,
964
+ ],
965
+ stop_btn=False,
966
+ examples=examples,
967
+ run_examples_on_click=False,
968
+ cache_examples=False,
969
+ css_paths=None,
970
+ delete_cache=(1800, 1800),
971
+ )
972
+
973
+ with gr.Row(elem_id="examples_row"):
974
+ with gr.Column(scale=12, elem_id="examples_container"):
975
+ gr.Markdown("### @Community https://discord.gg/openfreeai ")
976
+
977
+ if __name__ == "__main__":
978
+ demo.launch(share=True)