ginipick commited on
Commit
4dcee53
·
verified ·
1 Parent(s): 81117e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -87
app.py CHANGED
@@ -13,7 +13,7 @@ import base64
13
  import logging
14
  import time
15
  from urllib.parse import quote # Added for URL encoding
16
- import importlib # NEW: For dynamic import
17
 
18
  import gradio as gr
19
  import spaces
@@ -84,7 +84,6 @@ def generate_image(prompt: str, width: float, height: float, guidance: float, in
84
  logging.error(f"Image generation failed: {str(e)}")
85
  return None, f"Error: {str(e)}"
86
 
87
- # Base64 padding fix function
88
  def fix_base64_padding(data):
89
  """Fix the padding of a Base64 string."""
90
  if isinstance(data, bytes):
@@ -99,18 +98,12 @@ def fix_base64_padding(data):
99
 
100
  return data
101
 
102
- # =============================================================================
103
- # Memory cleanup function
104
- # =============================================================================
105
  def clear_cuda_cache():
106
  """Explicitly clear the CUDA cache."""
107
  if torch.cuda.is_available():
108
  torch.cuda.empty_cache()
109
  gc.collect()
110
 
111
- # =============================================================================
112
- # SerpHouse related functions
113
- # =============================================================================
114
  SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
115
 
116
  def extract_keywords(text: str, top_k: int = 5) -> str:
@@ -176,9 +169,6 @@ Below are the search results. Use this information to answer the query:
176
  logger.error(f"Web search failed: {e}")
177
  return f"Web search failed: {str(e)}"
178
 
179
- # =============================================================================
180
- # Model and processor loading
181
- # =============================================================================
182
  MAX_CONTENT_CHARS = 2000
183
  MAX_INPUT_LENGTH = 2096
184
  model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
@@ -191,9 +181,6 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
191
  )
192
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
193
 
194
- # =============================================================================
195
- # CSV, TXT, PDF analysis functions
196
- # =============================================================================
197
  def analyze_csv_file(path: str) -> str:
198
  try:
199
  df = pd.read_csv(path)
@@ -238,9 +225,6 @@ def pdf_to_markdown(pdf_path: str) -> str:
238
  full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
239
  return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
240
 
241
- # =============================================================================
242
- # Check media file limits
243
- # =============================================================================
244
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
245
  image_count = 0
246
  video_count = 0
@@ -293,9 +277,6 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
293
  return False
294
  return True
295
 
296
- # =============================================================================
297
- # Video processing functions
298
- # =============================================================================
299
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
300
  vidcap = cv2.VideoCapture(video_path)
301
  fps = vidcap.get(cv2.CAP_PROP_FPS)
@@ -328,9 +309,6 @@ def process_video(video_path: str) -> tuple[list[dict], list[str]]:
328
  content.append({"type": "image", "url": temp_file.name})
329
  return content, temp_files
330
 
331
- # =============================================================================
332
- # Interleaved <image> processing function
333
- # =============================================================================
334
  def process_interleaved_images(message: dict) -> list[dict]:
335
  parts = re.split(r"(<image>)", message["text"])
336
  content = []
@@ -347,9 +325,6 @@ def process_interleaved_images(message: dict) -> list[dict]:
347
  content.append({"type": "text", "text": part})
348
  return content
349
 
350
- # =============================================================================
351
- # File processing -> content creation
352
- # =============================================================================
353
  def is_image_file(file_path: str) -> bool:
354
  return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
355
 
@@ -390,9 +365,6 @@ def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
390
  content_list.append({"type": "image", "url": img_path})
391
  return content_list, temp_files
392
 
393
- # =============================================================================
394
- # Convert history to LLM messages
395
- # =============================================================================
396
  def process_history(history: list[dict]) -> list[dict]:
397
  messages = []
398
  current_user_content = []
@@ -416,9 +388,6 @@ def process_history(history: list[dict]) -> list[dict]:
416
  messages.append({"role": "user", "content": current_user_content})
417
  return messages
418
 
419
- # =============================================================================
420
- # Model generation function (with OOM catching)
421
- # =============================================================================
422
  def _model_gen_with_oom_catch(**kwargs):
423
  try:
424
  model.generate(**kwargs)
@@ -433,18 +402,10 @@ def _model_gen_with_oom_catch(**kwargs):
433
  def load_function_definitions(json_path="functions.json"):
434
  """
435
  로컬 JSON 파일에서 함수 정의 목록을 로드하여 반환.
436
- 각 항목: {
437
- "name": <str>,
438
- "description": <str>,
439
- "module_path": <str>,
440
- "func_name_in_module": <str>,
441
- "parameters": { ... }
442
- }
443
  """
444
  try:
445
  with open(json_path, "r", encoding="utf-8") as f:
446
  data = json.load(f)
447
- # name을 키로 하는 dict 형태로 재구성
448
  func_dict = {}
449
  for entry in data:
450
  func_name = entry["name"]
@@ -456,9 +417,6 @@ def load_function_definitions(json_path="functions.json"):
456
 
457
  FUNCTION_DEFINITIONS = load_function_definitions("functions.json")
458
 
459
- # =============================================================================
460
- # Dynamic handle_function_call
461
- # =============================================================================
462
  def handle_function_call(text: str) -> str:
463
  """
464
  Detects and processes function call blocks in the text using the JSON-based approach.
@@ -470,7 +428,6 @@ def handle_function_call(text: str) -> str:
470
  ```tool_code
471
  get_product_name_by_PID(PID="807ZPKBL9V")
472
  ```
473
- We parse that block, check if 'FUNCTION_DEFINITIONS' has an entry, then import & call it.
474
  """
475
  import re
476
  pattern = r"```tool_code\s*(.*?)\s*```"
@@ -479,12 +436,11 @@ def handle_function_call(text: str) -> str:
479
  return ""
480
  code_block = match.group(1).strip()
481
 
482
- # 함수명 추출 (예: get_stock_price)
483
- # 정규식: ^(\w+)\(.*\)
484
  func_match = re.match(r'^(\w+)\((.*)\)$', code_block)
485
  if not func_match:
486
  logger.debug("No valid function call format found.")
487
  return ""
 
488
  func_name = func_match.group(1)
489
  param_str = func_match.group(2).strip()
490
 
@@ -496,43 +452,35 @@ def handle_function_call(text: str) -> str:
496
  func_info = FUNCTION_DEFINITIONS[func_name]
497
  module_path = func_info["module_path"]
498
  module_func_name = func_info["func_name_in_module"]
499
- # 동적 임포트
500
  try:
501
  imported_module = importlib.import_module(module_path)
502
  except ImportError as e:
503
  logger.error(f"Failed to import module {module_path}: {e}")
504
  return f"```tool_output\nError: Cannot import module '{module_path}'\n```"
505
 
506
- # 실제 함수 객체를 가져옴
507
  if not hasattr(imported_module, module_func_name):
508
  logger.error(f"Module '{module_path}' has no attribute '{module_func_name}'.")
509
  return f"```tool_output\nError: Function '{module_func_name}' not found in module '{module_path}'\n```"
510
 
511
  real_func = getattr(imported_module, module_func_name)
512
 
513
- # 파라미터 파싱 예: ticker="AAPL", some_arg=123
514
- # 단순 정규식으로 key="value" or key=123 식을 구분
515
  param_pattern = r'(\w+)\s*=\s*"(.*?)"|(\w+)\s*=\s*([\d.]+)'
516
- # 이 정규식은 간단히 key="string" 또는 key=123 같은 형태를 파싱
517
- # 더 복잡한 경우 별도 파싱 로직이나 json.loads 기법 사용 필요
518
  param_dict = {}
519
  for p_match in re.finditer(param_pattern, param_str):
520
  if p_match.group(1) and p_match.group(2):
521
- # group(1)은 key, group(2)는 string value
522
  key = p_match.group(1)
523
  val = p_match.group(2)
524
  param_dict[key] = val
525
  else:
526
- # group(3)은 key, group(4)는 numeric value
527
  key = p_match.group(3)
528
  val = p_match.group(4)
529
- # 숫자 변환
530
  if '.' in val:
531
  param_dict[key] = float(val)
532
  else:
533
  param_dict[key] = int(val)
534
 
535
- # 이제 실제 함수 실행
536
  try:
537
  result = real_func(**param_dict)
538
  except Exception as e:
@@ -541,9 +489,6 @@ def handle_function_call(text: str) -> str:
541
 
542
  return f"```tool_output\n{result}\n```"
543
 
544
- # =============================================================================
545
- # Main inference function
546
- # =============================================================================
547
  @spaces.GPU(duration=120)
548
  def run(
549
  message: dict,
@@ -555,19 +500,18 @@ def run(
555
  age_group: str = "20s",
556
  mbti_personality: str = "INTP",
557
  sexual_openness: int = 2,
558
- image_gen: bool = False # "Image Gen" checkbox status
559
  ) -> Iterator[str]:
560
  if not validate_media_constraints(message, history):
561
  yield ""
562
  return
563
  temp_files = []
564
  try:
565
- # JSON에서 로드된 함수 목록을 요약해서 시스템 프롬프트에 포함할 수도 있음
566
- # (토큰 부담이 커질 수 있으므로, 적당히 압축 요약 권장)
567
- # 아래는 예시로 간단히 함수 이름만 나열
568
  available_funcs_text = ""
569
  for f_name, info in FUNCTION_DEFINITIONS.items():
570
- available_funcs_text += f"Function: {f_name} - {info['description']}\n"
 
571
 
572
  persona = (
573
  f"{system_prompt.strip()}\n\n"
@@ -575,7 +519,9 @@ def run(
575
  f"Age Group: {age_group}\n"
576
  f"MBTI Persona: {mbti_personality}\n"
577
  f"Sexual Openness (1-5): {sexual_openness}\n\n"
578
- "Below are the available functions you can call (use the format: ```tool_code\\nfunc_name(param=...)\n```):\n"
 
 
579
  f"{available_funcs_text}\n"
580
  )
581
  combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
@@ -629,7 +575,6 @@ def run(
629
  output_so_far += new_text
630
  yield output_so_far
631
 
632
- # 모델 출력 중 ```tool_code``` 블록이 있으면 처리
633
  func_result = handle_function_call(output_so_far)
634
  if func_result:
635
  output_so_far += "\n\n" + func_result
@@ -652,17 +597,12 @@ def run(
652
  pass
653
  clear_cuda_cache()
654
 
655
- # =============================================================================
656
- # Modified model run function - handles image generation and gallery update
657
- # =============================================================================
658
  def modified_run(message, history, system_prompt, max_new_tokens, use_web_search, web_search_query,
659
  age_group, mbti_personality, sexual_openness, image_gen):
660
- # Initialize and hide the gallery component
661
  output_so_far = ""
662
  gallery_update = gr.Gallery(visible=False, value=[])
663
  yield output_so_far, gallery_update
664
 
665
- # Execute the original run function
666
  text_generator = run(message, history, system_prompt, max_new_tokens, use_web_search,
667
  web_search_query, age_group, mbti_personality, sexual_openness, image_gen)
668
 
@@ -670,15 +610,12 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
670
  output_so_far = text_chunk
671
  yield output_so_far, gallery_update
672
 
673
- # If image generation is enabled and there is text input, update the gallery
674
  if image_gen and message["text"].strip():
675
  try:
676
  width, height = 512, 512
677
  guidance, steps, seed = 7.5, 30, 42
678
 
679
  logger.info(f"Calling image generation for gallery with prompt: {message['text']}")
680
-
681
- # Call the API to generate an image
682
  image_result, seed_info = generate_image(
683
  prompt=message["text"].strip(),
684
  width=width,
@@ -687,7 +624,6 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
687
  inference_steps=steps,
688
  seed=seed
689
  )
690
-
691
  if image_result:
692
  if isinstance(image_result, str) and (
693
  image_result.startswith('data:') or
@@ -699,22 +635,18 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
699
  else:
700
  b64data = image_result
701
  content_type = "image/webp"
702
-
703
  image_bytes = base64.b64decode(b64data)
704
  with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
705
  temp_file.write(image_bytes)
706
  temp_path = temp_file.name
707
  gallery_update = gr.Gallery(visible=True, value=[temp_path])
708
  yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
709
-
710
  except Exception as e:
711
  logger.error(f"Error processing Base64 image: {e}")
712
  yield output_so_far + f"\n\n(Error processing image: {e})", gallery_update
713
-
714
  elif isinstance(image_result, str) and os.path.exists(image_result):
715
  gallery_update = gr.Gallery(visible=True, value=[image_result])
716
  yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
717
-
718
  elif isinstance(image_result, str) and '/tmp/' in image_result:
719
  try:
720
  client = Client(API_URL)
@@ -722,13 +654,11 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
722
  prompt=message["text"].strip(),
723
  api_name="/generate_base64_image"
724
  )
725
-
726
  if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
727
  if result.startswith('data:'):
728
  content_type, b64data = result.split(';base64,')
729
  else:
730
  b64data = result
731
-
732
  image_bytes = base64.b64decode(b64data)
733
  with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
734
  temp_file.write(image_bytes)
@@ -737,7 +667,6 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
737
  yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
738
  else:
739
  yield output_so_far + "\n\n(Image generation failed: Invalid format)", gallery_update
740
-
741
  except Exception as e:
742
  logger.error(f"Error calling alternative API: {e}")
743
  yield output_so_far + f"\n\n(Image generation failed: {e})", gallery_update
@@ -755,14 +684,10 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
755
  yield output_so_far + f"\n\n(Unsupported image format: {type(image_result)})", gallery_update
756
  else:
757
  yield output_so_far + f"\n\n(Image generation failed: {seed_info})", gallery_update
758
-
759
  except Exception as e:
760
  logger.error(f"Error during gallery image generation: {e}")
761
  yield output_so_far + f"\n\n(Image generation error: {e})", gallery_update
762
 
763
- # =============================================================================
764
- # Examples
765
- # =============================================================================
766
  examples = [
767
  [
768
  {
@@ -855,7 +780,7 @@ examples = [
855
  ],
856
  [
857
  {
858
- "text": "AAPL의 현재 주가를 알려줘.",
859
  "files": []
860
  }
861
  ],
 
13
  import logging
14
  import time
15
  from urllib.parse import quote # Added for URL encoding
16
+ import importlib # For dynamic import
17
 
18
  import gradio as gr
19
  import spaces
 
84
  logging.error(f"Image generation failed: {str(e)}")
85
  return None, f"Error: {str(e)}"
86
 
 
87
  def fix_base64_padding(data):
88
  """Fix the padding of a Base64 string."""
89
  if isinstance(data, bytes):
 
98
 
99
  return data
100
 
 
 
 
101
  def clear_cuda_cache():
102
  """Explicitly clear the CUDA cache."""
103
  if torch.cuda.is_available():
104
  torch.cuda.empty_cache()
105
  gc.collect()
106
 
 
 
 
107
  SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
108
 
109
  def extract_keywords(text: str, top_k: int = 5) -> str:
 
169
  logger.error(f"Web search failed: {e}")
170
  return f"Web search failed: {str(e)}"
171
 
 
 
 
172
  MAX_CONTENT_CHARS = 2000
173
  MAX_INPUT_LENGTH = 2096
174
  model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
 
181
  )
182
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
183
 
 
 
 
184
  def analyze_csv_file(path: str) -> str:
185
  try:
186
  df = pd.read_csv(path)
 
225
  full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
226
  return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
227
 
 
 
 
228
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
229
  image_count = 0
230
  video_count = 0
 
277
  return False
278
  return True
279
 
 
 
 
280
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
281
  vidcap = cv2.VideoCapture(video_path)
282
  fps = vidcap.get(cv2.CAP_PROP_FPS)
 
309
  content.append({"type": "image", "url": temp_file.name})
310
  return content, temp_files
311
 
 
 
 
312
  def process_interleaved_images(message: dict) -> list[dict]:
313
  parts = re.split(r"(<image>)", message["text"])
314
  content = []
 
325
  content.append({"type": "text", "text": part})
326
  return content
327
 
 
 
 
328
  def is_image_file(file_path: str) -> bool:
329
  return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
330
 
 
365
  content_list.append({"type": "image", "url": img_path})
366
  return content_list, temp_files
367
 
 
 
 
368
  def process_history(history: list[dict]) -> list[dict]:
369
  messages = []
370
  current_user_content = []
 
388
  messages.append({"role": "user", "content": current_user_content})
389
  return messages
390
 
 
 
 
391
  def _model_gen_with_oom_catch(**kwargs):
392
  try:
393
  model.generate(**kwargs)
 
402
  def load_function_definitions(json_path="functions.json"):
403
  """
404
  로컬 JSON 파일에서 함수 정의 목록을 로드하여 반환.
 
 
 
 
 
 
 
405
  """
406
  try:
407
  with open(json_path, "r", encoding="utf-8") as f:
408
  data = json.load(f)
 
409
  func_dict = {}
410
  for entry in data:
411
  func_name = entry["name"]
 
417
 
418
  FUNCTION_DEFINITIONS = load_function_definitions("functions.json")
419
 
 
 
 
420
  def handle_function_call(text: str) -> str:
421
  """
422
  Detects and processes function call blocks in the text using the JSON-based approach.
 
428
  ```tool_code
429
  get_product_name_by_PID(PID="807ZPKBL9V")
430
  ```
 
431
  """
432
  import re
433
  pattern = r"```tool_code\s*(.*?)\s*```"
 
436
  return ""
437
  code_block = match.group(1).strip()
438
 
 
 
439
  func_match = re.match(r'^(\w+)\((.*)\)$', code_block)
440
  if not func_match:
441
  logger.debug("No valid function call format found.")
442
  return ""
443
+
444
  func_name = func_match.group(1)
445
  param_str = func_match.group(2).strip()
446
 
 
452
  func_info = FUNCTION_DEFINITIONS[func_name]
453
  module_path = func_info["module_path"]
454
  module_func_name = func_info["func_name_in_module"]
455
+
456
  try:
457
  imported_module = importlib.import_module(module_path)
458
  except ImportError as e:
459
  logger.error(f"Failed to import module {module_path}: {e}")
460
  return f"```tool_output\nError: Cannot import module '{module_path}'\n```"
461
 
 
462
  if not hasattr(imported_module, module_func_name):
463
  logger.error(f"Module '{module_path}' has no attribute '{module_func_name}'.")
464
  return f"```tool_output\nError: Function '{module_func_name}' not found in module '{module_path}'\n```"
465
 
466
  real_func = getattr(imported_module, module_func_name)
467
 
468
+ # 간단 파라미터 파싱 (key="value" or key=123)
 
469
  param_pattern = r'(\w+)\s*=\s*"(.*?)"|(\w+)\s*=\s*([\d.]+)'
 
 
470
  param_dict = {}
471
  for p_match in re.finditer(param_pattern, param_str):
472
  if p_match.group(1) and p_match.group(2):
 
473
  key = p_match.group(1)
474
  val = p_match.group(2)
475
  param_dict[key] = val
476
  else:
 
477
  key = p_match.group(3)
478
  val = p_match.group(4)
 
479
  if '.' in val:
480
  param_dict[key] = float(val)
481
  else:
482
  param_dict[key] = int(val)
483
 
 
484
  try:
485
  result = real_func(**param_dict)
486
  except Exception as e:
 
489
 
490
  return f"```tool_output\n{result}\n```"
491
 
 
 
 
492
  @spaces.GPU(duration=120)
493
  def run(
494
  message: dict,
 
500
  age_group: str = "20s",
501
  mbti_personality: str = "INTP",
502
  sexual_openness: int = 2,
503
+ image_gen: bool = False
504
  ) -> Iterator[str]:
505
  if not validate_media_constraints(message, history):
506
  yield ""
507
  return
508
  temp_files = []
509
  try:
510
+ # JSON에서 로드된 함수 정보 문자열화 (예: 함수명과 example_usage만)
 
 
511
  available_funcs_text = ""
512
  for f_name, info in FUNCTION_DEFINITIONS.items():
513
+ example_usage = info.get("example_usage", "")
514
+ available_funcs_text += f"\n\nFunction: {f_name}\nDescription: {info['description']}\nExample:\n{example_usage}\n"
515
 
516
  persona = (
517
  f"{system_prompt.strip()}\n\n"
 
519
  f"Age Group: {age_group}\n"
520
  f"MBTI Persona: {mbti_personality}\n"
521
  f"Sexual Openness (1-5): {sexual_openness}\n\n"
522
+ "Below are the available functions you can call.\n"
523
+ "Important: Use the format exactly like: ```tool_code\nfunctionName(param=\"string\", ...)\n```\n"
524
+ "(Strings must be in double quotes)\n"
525
  f"{available_funcs_text}\n"
526
  )
527
  combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
 
575
  output_so_far += new_text
576
  yield output_so_far
577
 
 
578
  func_result = handle_function_call(output_so_far)
579
  if func_result:
580
  output_so_far += "\n\n" + func_result
 
597
  pass
598
  clear_cuda_cache()
599
 
 
 
 
600
  def modified_run(message, history, system_prompt, max_new_tokens, use_web_search, web_search_query,
601
  age_group, mbti_personality, sexual_openness, image_gen):
 
602
  output_so_far = ""
603
  gallery_update = gr.Gallery(visible=False, value=[])
604
  yield output_so_far, gallery_update
605
 
 
606
  text_generator = run(message, history, system_prompt, max_new_tokens, use_web_search,
607
  web_search_query, age_group, mbti_personality, sexual_openness, image_gen)
608
 
 
610
  output_so_far = text_chunk
611
  yield output_so_far, gallery_update
612
 
 
613
  if image_gen and message["text"].strip():
614
  try:
615
  width, height = 512, 512
616
  guidance, steps, seed = 7.5, 30, 42
617
 
618
  logger.info(f"Calling image generation for gallery with prompt: {message['text']}")
 
 
619
  image_result, seed_info = generate_image(
620
  prompt=message["text"].strip(),
621
  width=width,
 
624
  inference_steps=steps,
625
  seed=seed
626
  )
 
627
  if image_result:
628
  if isinstance(image_result, str) and (
629
  image_result.startswith('data:') or
 
635
  else:
636
  b64data = image_result
637
  content_type = "image/webp"
 
638
  image_bytes = base64.b64decode(b64data)
639
  with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
640
  temp_file.write(image_bytes)
641
  temp_path = temp_file.name
642
  gallery_update = gr.Gallery(visible=True, value=[temp_path])
643
  yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
 
644
  except Exception as e:
645
  logger.error(f"Error processing Base64 image: {e}")
646
  yield output_so_far + f"\n\n(Error processing image: {e})", gallery_update
 
647
  elif isinstance(image_result, str) and os.path.exists(image_result):
648
  gallery_update = gr.Gallery(visible=True, value=[image_result])
649
  yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
 
650
  elif isinstance(image_result, str) and '/tmp/' in image_result:
651
  try:
652
  client = Client(API_URL)
 
654
  prompt=message["text"].strip(),
655
  api_name="/generate_base64_image"
656
  )
 
657
  if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
658
  if result.startswith('data:'):
659
  content_type, b64data = result.split(';base64,')
660
  else:
661
  b64data = result
 
662
  image_bytes = base64.b64decode(b64data)
663
  with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
664
  temp_file.write(image_bytes)
 
667
  yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
668
  else:
669
  yield output_so_far + "\n\n(Image generation failed: Invalid format)", gallery_update
 
670
  except Exception as e:
671
  logger.error(f"Error calling alternative API: {e}")
672
  yield output_so_far + f"\n\n(Image generation failed: {e})", gallery_update
 
684
  yield output_so_far + f"\n\n(Unsupported image format: {type(image_result)})", gallery_update
685
  else:
686
  yield output_so_far + f"\n\n(Image generation failed: {seed_info})", gallery_update
 
687
  except Exception as e:
688
  logger.error(f"Error during gallery image generation: {e}")
689
  yield output_so_far + f"\n\n(Image generation error: {e})", gallery_update
690
 
 
 
 
691
  examples = [
692
  [
693
  {
 
780
  ],
781
  [
782
  {
783
+ "text": "AAPL의 현재 주가를 알려줘.",
784
  "files": []
785
  }
786
  ],