#!/usr/bin/env python

import os
import re
import tempfile
from collections.abc import Iterator
from threading import Thread

import cv2
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

# CSV/TXT analysis
import pandas as pd

# PDF text extraction
import PyPDF2

MAX_CONTENT_CHARS = 8000  # display at most 8000 characters to keep very large files manageable

model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)

MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))


##################################################
# CSV, TXT, and PDF analysis functions
##################################################
def analyze_csv_file(path: str) -> str:
    """
    Convert a CSV file to a single string. Show only part of it if it is too long.
    """
    try:
        df = pd.read_csv(path)
        df_str = df.to_string()
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
    except Exception as e:
        return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"


def analyze_txt_file(path: str) -> str:
    """
    Read a TXT file in full. Show only part of it if it is too long.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        if len(text) > MAX_CONTENT_CHARS:
            text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
    except Exception as e:
        return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"


def pdf_to_markdown(pdf_path: str) -> str:
    """
    PDF -> Markdown. Extract plain text page by page.
    """
    text_chunks = []
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            for page_num, page in enumerate(reader.pages, start=1):
                page_text = page.extract_text() or ""
                page_text = page_text.strip()
                if page_text:
                    text_chunks.append(f"## Page {page_num}\n\n{page_text}\n")
    except Exception as e:
        return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"

    full_text = "\n".join(text_chunks)
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."

    return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
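# Illustrative sketch only: how the file-analysis helpers above can be exercised
# outside the Gradio app. Not called anywhere; the paths below are hypothetical
# placeholders, not files shipped with this Space.
def _demo_file_analysis() -> None:
    for path in ["data/table.csv", "data/notes.txt", "data/report.pdf"]:
        if path.lower().endswith(".csv"):
            print(analyze_csv_file(path))
        elif path.lower().endswith(".txt"):
            print(analyze_txt_file(path))
        elif path.lower().endswith(".pdf"):
            print(pdf_to_markdown(path))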
##################################################
# Image/video upload constraint checks
##################################################
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    image_count = 0
    video_count = 0
    for path in paths:
        if path.endswith(".mp4"):
            video_count += 1
        else:
            image_count += 1
    return image_count, video_count


def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    image_count = 0
    video_count = 0
    for item in history:
        if item["role"] != "user" or isinstance(item["content"], str):
            continue
        if item["content"][0].endswith(".mp4"):
            video_count += 1
        else:
            image_count += 1
    return image_count, video_count


def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """
    - At most one video
    - Images and videos cannot be mixed
    - No more than MAX_NUM_IMAGES images
    - If <image> tags are present, their count must match the number of images
    - CSV, TXT, and PDF files are not restricted here
    """
    media_files = []
    for f in message["files"]:
        # Images: png/jpg/jpeg/gif/webp
        # Videos: mp4
        # PDF, CSV, TXT, etc. are excluded here
        if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"):
            media_files.append(f)

    new_image_count, new_video_count = count_files_in_new_message(media_files)
    history_image_count, history_video_count = count_files_in_history(history)
    image_count = history_image_count + new_image_count
    video_count = history_video_count + new_video_count

    if video_count > 1:
        gr.Warning("Only one video is supported.")
        return False
    if video_count == 1:
        if image_count > 0:
            gr.Warning("Mixing images and videos is not allowed.")
            return False
        if "<image>" in message["text"]:
            gr.Warning("Using <image> tags with video files is not supported.")
            return False
    if video_count == 0 and image_count > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
        return False
    if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
        gr.Warning("The number of <image> tags in the text does not match the number of images.")
        return False

    return True


##################################################
# Video processing
##################################################
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
    vidcap = cv2.VideoCapture(video_path)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_interval = int(fps / 3)  # sample roughly three frames per second
    frames = []

    for i in range(0, total_frames, frame_interval):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))

    vidcap.release()
    return frames


def process_video(video_path: str) -> list[dict]:
    content = []
    frames = downsample_video(video_path)
    for frame in frames:
        pil_image, timestamp = frame
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name)
            content.append({"type": "text", "text": f"Frame {timestamp}:"})
            content.append({"type": "image", "url": temp_file.name})
    logger.debug(f"{content=}")
    return content


##################################################
# Interleaved <image> processing
##################################################
def process_interleaved_images(message: dict) -> list[dict]:
    parts = re.split(r"(<image>)", message["text"])
    content = []
    image_index = 0
    for part in parts:
        if part == "<image>":
            content.append({"type": "image", "url": message["files"][image_index]})
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        else:
            # whitespace-only or newline-only fragment
            if isinstance(part, str) and part != "":
                content.append({"type": "text", "text": part})
    return content
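# Illustrative sketch only: the content structure process_interleaved_images() is
# expected to return for a prompt containing two <image> tags. Not called anywhere;
# the file names are hypothetical placeholders.
def _demo_interleaved_message() -> list[dict]:
    message = {
        "text": "Compare <image> with <image> and describe the differences.",
        "files": ["cat.png", "dog.png"],
    }
    # Expected shape (interleaved text and images, in prompt order):
    # [{"type": "text", "text": "Compare"}, {"type": "image", "url": "cat.png"},
    #  {"type": "text", "text": "with"},    {"type": "image", "url": "dog.png"},
    #  {"type": "text", "text": "and describe the differences."}]
    return process_interleaved_images(message)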
##################################################
# PDF + CSV + TXT + image/video
##################################################
def process_new_user_message(message: dict) -> list[dict]:
    if not message["files"]:
        return [{"type": "text", "text": message["text"]}]

    # 1) Classify the files
    video_files = [f for f in message["files"] if f.endswith(".mp4")]
    image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
    csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
    txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
    pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]

    # 2) Start with the user's original text
    content_list = [{"type": "text", "text": message["text"]}]

    # 3) CSV
    for csv_path in csv_files:
        csv_analysis = analyze_csv_file(csv_path)
        content_list.append({"type": "text", "text": csv_analysis})

    # 4) TXT
    for txt_path in txt_files:
        txt_analysis = analyze_txt_file(txt_path)
        content_list.append({"type": "text", "text": txt_analysis})

    # 5) PDF
    for pdf_path in pdf_files:
        pdf_markdown = pdf_to_markdown(pdf_path)
        content_list.append({"type": "text", "text": pdf_markdown})

    # 6) Video (only one allowed)
    if video_files:
        content_list += process_video(video_files[0])
        return content_list

    # 7) Images
    if "<image>" in message["text"]:
        # interleaved
        return process_interleaved_images(message)
    else:
        # plain multi-image upload
        for img_path in image_files:
            content_list.append({"type": "image", "url": img_path})

    return content_list


##################################################
# history -> LLM message conversion
##################################################
def process_history(history: list[dict]) -> list[dict]:
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            # If user content has accumulated, flush it as a single user message first
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            # Then append this assistant item
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            # user
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                # image or other file
                current_user_content.append({"type": "image", "url": content[0]})
    return messages


##################################################
# Main inference function
##################################################
@spaces.GPU(duration=120)
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    if not validate_media_constraints(message, history):
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    t = Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output
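# Illustrative sketch only: driving run() programmatically instead of through the
# ChatInterface. Not called anywhere; assumes assets/sample.pdf exists (it is also
# referenced by the examples below). Each yielded value is the full output so far,
# not a delta, so only the last one is kept.
def _demo_run_once() -> str:
    message = {"text": "Summarize this document.", "files": ["assets/sample.pdf"]}
    last = ""
    for partial in run(message, history=[], system_prompt="You are a helpful assistant.", max_new_tokens=256):
        last = partial
    return last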
"text": "Test with PDF", "files": ["assets/sample.pdf"], } ], [ { "text": "Simple text with CSV upload.", "files": ["assets/sample.csv"], } ], # ...원래 예시들 유지... ] demo = gr.ChatInterface( fn=run, type="messages", chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]), # .webp, .png, .jpg, .jpeg, .gif, .mp4, .csv, .txt, .pdf 모두 허용 textbox=gr.MultimodalTextbox( file_types=[ ".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf" ], file_count="multiple", autofocus=True ), multimodal=True, additional_inputs=[ gr.Textbox( label="System Prompt", value=( "You are a deeply thoughtful AI. Consider problems thoroughly and derive " "correct solutions through systematic reasoning. Please answer in korean." ) ), gr.Slider(label="Max New Tokens", minimum=100, maximum=8000, step=50, value=2000), ], stop_btn=False, title="Gemma 3 27B IT", examples=examples, run_examples_on_click=False, cache_examples=False, css_paths="style.css", delete_cache=(1800, 1800), ) if __name__ == "__main__": demo.launch()