#!/usr/bin/env python
import os
import re
import tempfile
import gc  # garbage collector
from collections.abc import Iterator
from threading import Thread
import json
import requests
import cv2
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

# CSV/TXT analysis
import pandas as pd

# PDF text extraction
import PyPDF2


##############################################################################
# Memory cleanup function
##############################################################################
def clear_cuda_cache():
    """Release cached CUDA memory (when a GPU is available) and run the garbage collector."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


##############################################################################
# SERPHouse API key from environment variable
##############################################################################
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")


##############################################################################
# Simple keyword extraction function
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
    """
    Build a short search query from ``text``.

    Removes every character that is not an ASCII letter, digit, Hangul
    syllable, or whitespace, then returns the first ``top_k`` whitespace-
    separated tokens joined by single spaces (no ranking is applied).
    """
    text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
    tokens = text.split()
    key_tokens = tokens[:top_k]
    return " ".join(key_tokens)


##############################################################################
# SerpHouse Live endpoint call
##############################################################################
def do_web_search(query: str) -> str:
    """
    Query the SerpHouse live endpoint and return up to 20 'organic' results
    formatted as Markdown.

    Never raises: on any failure (missing API key, HTTP error, unexpected
    response shape) a human-readable message is returned instead, because the
    result is embedded directly into the system prompt.
    """
    if not SERPHOUSE_API_KEY:
        # Without a key the API can only answer 401; skip the network round-trip.
        logger.warning("SERPHOUSE_API_KEY is not set; skipping web search.")
        return "Web search skipped: SERPHOUSE_API_KEY is not set."

    try:
        url = "https://api.serphouse.com/serp/live"

        # Plain GET request with a minimal parameter set; ask for at most 20 results.
        params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",  # Basic web search
            "device": "desktop",
            "lang": "en",
            "num": "20",  # Request max 20 results
        }

        headers = {
            "Authorization": f"Bearer {SERPHOUSE_API_KEY}"
        }

        logger.info(f"SerpHouse API call... query: {query}")
        logger.info(f"Request URL: {url} - params: {params}")

        # GET request
        response = requests.get(url, headers=headers, params=params, timeout=60)
        response.raise_for_status()

        logger.info(f"SerpHouse API response status: {response.status_code}")
        data = response.json()

        # The API has been observed with several layouts; probe each in turn.
        results = data.get("results", {})
        organic = None

        # Possible response structure 1
        if isinstance(results, dict) and "organic" in results:
            organic = results["organic"]

        # Possible response structure 2 (nested results)
        elif isinstance(results, dict) and "results" in results:
            if isinstance(results["results"], dict) and "organic" in results["results"]:
                organic = results["results"]["organic"]

        # Possible response structure 3 (top-level organic)
        elif "organic" in data:
            organic = data["organic"]

        if not organic:
            logger.warning("No organic results found in response.")
            logger.debug(f"Response structure: {list(data.keys())}")
            if isinstance(results, dict):
                logger.debug(f"results structure: {list(results.keys())}")
            return "No web search results found or unexpected API response structure."

        # Limit results and optimize context length
        max_results = min(20, len(organic))
        limited_organic = organic[:max_results]

        # Format results for better readability
        summary_lines = []
        for idx, item in enumerate(limited_organic, start=1):
            title = item.get("title", "No title")
            link = item.get("link", "#")
            snippet = item.get("snippet", "No description")
            displayed_link = item.get("displayed_link", link)

            # Markdown format
            summary_lines.append(
                f"### Result {idx}: {title}\n\n"
                f"{snippet}\n\n"
                f"**Source**: [{displayed_link}]({link})\n\n"
                f"---\n"
            )

        # Add simple instructions for model
        instructions = """
# X-RAY Security Scanning Reference Results
Use this information to enhance your analysis.
"""

        search_results = instructions + "\n".join(summary_lines)
        logger.info(f"Processed {len(limited_organic)} search results")
        return search_results

    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return f"Web search failed: {str(e)}"


##############################################################################
# Model/Processor loading
##############################################################################
MAX_CONTENT_CHARS = 2000
MAX_INPUT_LENGTH = 2096  # Max input token limit
model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")

processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # Change to "flash_attention_2" if available
)
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))


##############################################################################
# CSV, TXT, PDF analysis functions
##############################################################################
def analyze_csv_file(path: str) -> str:
    """
    Render a CSV file as a text table (at most 50 rows x 10 columns),
    truncated to MAX_CONTENT_CHARS. Returns an error message on failure.
    """
    try:
        df = pd.read_csv(path)
        if df.shape[0] > 50 or df.shape[1] > 10:
            df = df.iloc[:50, :10]
        df_str = df.to_string()
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
    except Exception as e:
        return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"


def analyze_txt_file(path: str) -> str:
    """
    Read a UTF-8 text file, truncated to MAX_CONTENT_CHARS.
    Returns an error message on failure (including decode errors).
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        if len(text) > MAX_CONTENT_CHARS:
            text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
    except Exception as e:
        return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"


def pdf_to_markdown(pdf_path: str) -> str:
    """
    Convert the text of (at most) the first 5 PDF pages to Markdown.

    Each page is capped at MAX_CONTENT_CHARS // (pages read) characters and
    the joined result at MAX_CONTENT_CHARS. Returns an error message on failure.
    """
    text_chunks = []
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            max_pages = min(5, len(reader.pages))
            for page_num in range(max_pages):
                page = reader.pages[page_num]
                page_text = page.extract_text() or ""
                page_text = page_text.strip()
                if page_text:
                    if len(page_text) > MAX_CONTENT_CHARS // max_pages:
                        page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
                    text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
            if len(reader.pages) > max_pages:
                text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...")
    except Exception as e:
        return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"

    full_text = "\n".join(text_chunks)
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."

    return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"


##############################################################################
# Image/Video upload limit check
##############################################################################
# Placeholder users type in the message text to position images between
# text segments (interleaved input). Must be a non-empty literal: an empty
# string is contained in every string.
IMAGE_TAG = "<image>"


def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Return ``(image_count, video_count)`` for the given file paths."""
    image_count = 0
    video_count = 0
    for path in paths:
        if is_video_file(path):
            video_count += 1
        elif is_image_file(path):
            image_count += 1
    return image_count, video_count


def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """
    Return ``(image_count, video_count)`` over the user turns of ``history``.

    Only user items whose content is a non-empty list are inspected; the first
    element of that list is taken as the file path.
    """
    image_count = 0
    video_count = 0
    for item in history:
        if item["role"] != "user" or isinstance(item["content"], str):
            continue
        if isinstance(item["content"], list) and len(item["content"]) > 0:
            file_path = item["content"][0]
            if isinstance(file_path, str):
                if is_video_file(file_path):
                    video_count += 1
                elif is_image_file(file_path):
                    image_count += 1
    return image_count, video_count


def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """
    Check upload limits for the new message plus the history; show a Gradio
    warning and return False on the first violated rule:

    - at most one video in the conversation;
    - no mixing of images and video, and no image tags with a video;
    - at most MAX_NUM_IMAGES images;
    - when image tags are used, their count must equal the number of
      image files in the new message.
    """
    media_files = [f for f in message["files"] if is_image_file(f) or is_video_file(f)]

    new_image_count, new_video_count = count_files_in_new_message(media_files)
    history_image_count, history_video_count = count_files_in_history(history)
    image_count = history_image_count + new_image_count
    video_count = history_video_count + new_video_count

    if video_count > 1:
        gr.Warning("Only one video is supported.")
        return False
    if video_count == 1:
        if image_count > 0:
            gr.Warning("Mixing images and videos is not allowed.")
            return False
        if IMAGE_TAG in message["text"]:
            gr.Warning(f"Using {IMAGE_TAG} tags with video files is not supported.")
            return False
    if video_count == 0 and image_count > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
        return False

    if IMAGE_TAG in message["text"]:
        image_files = [f for f in message["files"] if is_image_file(f)]
        image_tag_count = message["text"].count(IMAGE_TAG)
        if image_tag_count != len(image_files):
            gr.Warning(f"The number of {IMAGE_TAG} tags in the text does not match the number of image files.")
            return False

    return True


##############################################################################
# Video processing - with temp file tracking
##############################################################################
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
    """
    Sample at most 5 frames from a video as half-resolution RGB PIL images,
    each paired with its timestamp in seconds (rounded to 2 decimals).

    Returns an empty list when OpenCV reports no usable fps or frame count
    (e.g. an unreadable file) instead of failing on a zero step/division.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # `not fps > 0` also rejects NaN.
        if not fps > 0 or total_frames <= 0:
            logger.warning(f"Cannot read frames from video: {video_path}")
            return frames

        # At least one frame per second of video, and roughly 10 samples over
        # the whole clip; never a zero step.
        frame_interval = max(1, int(fps), total_frames // 10)

        for i in range(0, total_frames, frame_interval):
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if success:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                # Resize image
                image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
                pil_image = Image.fromarray(image)
                timestamp = round(i / fps, 2)
                frames.append((pil_image, timestamp))
                if len(frames) >= 5:
                    break
    finally:
        vidcap.release()
    return frames


def process_video(video_path: str) -> tuple[list[dict], list[str]]:
    """
    Turn a video into chat content: a "Frame <t>:" text item followed by an
    image item per sampled frame. Frames are written to temporary PNG files
    whose paths are returned so the caller can delete them.
    """
    content = []
    temp_files = []  # List for tracking temp files

    frames = downsample_video(video_path)
    for frame in frames:
        pil_image, timestamp = frame
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name)
            temp_files.append(temp_file.name)  # Track for deletion later
        content.append({"type": "text", "text": f"Frame {timestamp}:"})
        content.append({"type": "image", "url": temp_file.name})

    return content, temp_files


##############################################################################
# interleaved processing
##############################################################################
def process_interleaved_images(message: dict) -> list[dict]:
    """
    Split ``message["text"]`` on IMAGE_TAG and replace each tag, in order,
    with the next image file from ``message["files"]``.

    Text segments are stripped; whitespace-only segments are kept verbatim.
    A tag with no remaining image file is kept as text.
    """
    parts = re.split(f"({re.escape(IMAGE_TAG)})", message["text"])

    content = []
    image_index = 0
    image_files = [f for f in message["files"] if is_image_file(f)]

    for part in parts:
        if part == IMAGE_TAG and image_index < len(image_files):
            content.append({"type": "image", "url": image_files[image_index]})
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        elif part:
            content.append({"type": "text", "text": part})
    return content


##############################################################################
# PDF + CSV + TXT + Image/Video
##############################################################################
_IMAGE_EXT_PATTERN = re.compile(r"\.(png|jpg|jpeg|gif|webp)$", re.IGNORECASE)


def is_image_file(file_path: str) -> bool:
    """True for .png/.jpg/.jpeg/.gif/.webp paths (case-insensitive)."""
    return bool(_IMAGE_EXT_PATTERN.search(file_path))


def is_video_file(file_path: str) -> bool:
    """True for .mp4 paths (case-insensitive, consistent with is_image_file)."""
    return file_path.lower().endswith(".mp4")


def is_document_file(file_path: str) -> bool:
    """True for .pdf/.csv/.txt paths (case-insensitive)."""
    return file_path.lower().endswith((".pdf", ".csv", ".txt"))


def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
    """
    Build the chat content list for the incoming user message.

    Order: the message text, then CSV/TXT/PDF summaries, then either the
    first video's frames, interleaved images (when the text contains
    IMAGE_TAG), or the images appended at the end. Also returns temp files
    created along the way so the caller can delete them.
    """
    temp_files = []  # List for tracking temp files

    if not message["files"]:
        return [{"type": "text", "text": message["text"]}], temp_files

    video_files = [f for f in message["files"] if is_video_file(f)]
    image_files = [f for f in message["files"] if is_image_file(f)]
    csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
    txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
    pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]

    content_list = [{"type": "text", "text": message["text"]}]

    for csv_path in csv_files:
        content_list.append({"type": "text", "text": analyze_csv_file(csv_path)})

    for txt_path in txt_files:
        content_list.append({"type": "text", "text": analyze_txt_file(txt_path)})

    for pdf_path in pdf_files:
        content_list.append({"type": "text", "text": pdf_to_markdown(pdf_path)})

    if video_files:
        video_content, video_temp_files = process_video(video_files[0])
        content_list += video_content
        temp_files.extend(video_temp_files)
        return content_list, temp_files

    if IMAGE_TAG in message["text"] and image_files:
        interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files})
        # The interleaved content already carries the message text; drop the
        # plain copy at the head of content_list.
        if content_list and content_list[0]["type"] == "text":
            content_list = content_list[1:]
        return interleaved_content + content_list, temp_files
    else:
        for img_path in image_files:
            content_list.append({"type": "image", "url": img_path})

    return content_list, temp_files


##############################################################################
# history -> LLM message conversion
##############################################################################
def process_history(history: list[dict]) -> list[dict]:
    """
    Convert Gradio "messages" history into chat-template messages.

    Consecutive user items are merged into one user turn. A user item whose
    content is a list contributes its first element as a file path: images
    become image items, any other file becomes a "[File: name]" text note.
    """
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list) and len(content) > 0:
                file_path = content[0]
                if is_image_file(file_path):
                    current_user_content.append({"type": "image", "url": file_path})
                else:
                    current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})

    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})

    return messages


##############################################################################
# Model generation function with OOM catch
##############################################################################
def _model_gen_with_oom_catch(**kwargs):
    """
    Run ``model.generate(**kwargs)``; intended as a ``threading.Thread`` target.

    Exceptions raised in a thread target never reach the thread that consumes
    the ``TextIteratorStreamer``, which would otherwise just wait for the
    streamer timeout. On failure the error message is pushed into the
    streamer (``kwargs["streamer"]``) together with the end-of-stream signal
    so the consumer shows it and stops promptly.
    """
    streamer = kwargs.get("streamer")
    error_message = None
    try:
        model.generate(**kwargs)
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory during generation")
        error_message = (
            "[OutOfMemoryError] GPU memory insufficient. "
            "Please reduce Max New Tokens or prompt length."
        )
    except Exception as e:
        logger.exception("Generation failed")
        error_message = f"[Generation error] {e}"
    finally:
        # Clear cache after generation
        clear_cuda_cache()

    if error_message is not None and streamer is not None:
        streamer.on_finalized_text("\n" + error_message, stream_end=True)


def _truncate_model_inputs(inputs, max_length: int):
    """
    Left-truncate the per-token tensors of ``inputs`` to ``max_length`` tokens.

    ``inputs`` is the mapping returned by ``processor.apply_chat_template``.
    Tensors must be replaced via item assignment (``inputs[key] = ...``):
    attribute assignment (``inputs.input_ids = ...``) does not update the
    mapping that ``dict(inputs)`` later hands to ``model.generate``.

    Every tensor whose second dimension equals the ``input_ids`` length
    (e.g. ``attention_mask``, ``token_type_ids``) is sliced together so they
    stay aligned. Prompts that carry ``pixel_values`` are left untouched:
    cutting the sequence could drop image placeholder tokens while the
    pixel data still describes every image.
    """
    seq_len = inputs["input_ids"].shape[1]
    if seq_len <= max_length:
        return inputs

    if "pixel_values" in inputs:
        logger.warning(
            f"Input has {seq_len} tokens (> {max_length}) but contains images; not truncating."
        )
        return inputs

    for key in list(inputs.keys()):
        value = inputs[key]
        if isinstance(value, torch.Tensor) and value.dim() >= 2 and value.shape[1] == seq_len:
            inputs[key] = value[:, -max_length:]
    logger.warning(f"Input truncated from {seq_len} to {max_length} tokens.")
    return inputs


##############################################################################
# Main inference function (with auto web search)
##############################################################################
@spaces.GPU(duration=120)
def run(
    message: dict,
    history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 512,
    use_web_search: bool = False,
    web_search_query: str = "",
) -> Iterator[str]:
    """
    Stream the model's answer for ``message`` given the chat ``history``.

    Yields the accumulated output after every new chunk. When
    ``use_web_search`` is set, keywords extracted from the message text are
    searched and the results added to the (hidden) system message.
    ``web_search_query`` is accepted for the UI wiring but currently unused.
    Errors are reported as a final "Error occurred: ..." chunk; temporary
    video-frame files are always deleted.
    """
    if not validate_media_constraints(message, history):
        yield ""
        return

    temp_files: list[str] = []  # For tracking temp files
    inputs = None
    streamer = None

    try:
        combined_system_msg = ""

        # Used internally only (hidden from UI)
        if system_prompt.strip():
            combined_system_msg += f"[System Prompt]\n{system_prompt.strip()}\n\n"

        if use_web_search:
            user_text = message["text"]
            ws_query = extract_keywords(user_text, top_k=5)
            if ws_query.strip():
                logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
                ws_result = do_web_search(ws_query)
                combined_system_msg += f"[X-RAY Security Reference Data]\n{ws_result}\n\n"
            else:
                combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"

        messages = []
        if combined_system_msg.strip():
            messages.append({
                "role": "system",
                "content": [{"type": "text", "text": combined_system_msg.strip()}],
            })

        messages.extend(process_history(history))

        user_content, user_temp_files = process_new_user_message(message)
        temp_files.extend(user_temp_files)  # Track temp files

        for item in user_content:
            if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
                item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        messages.append({"role": "user", "content": user_content})

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device=model.device, dtype=torch.bfloat16)

        # Limit input token count
        inputs = _truncate_model_inputs(inputs, MAX_INPUT_LENGTH)

        streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
        )

        t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
        t.start()

        output = ""
        for new_text in streamer:
            output += new_text
            yield output

    except Exception as e:
        logger.error(f"Error in run: {str(e)}")
        yield f"Error occurred: {str(e)}"

    finally:
        # Delete temp files
        for temp_file in temp_files:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
                    logger.info(f"Deleted temp file: {temp_file}")
            except Exception as e:
                logger.warning(f"Failed to delete temp file {temp_file}: {e}")

        # Drop references to the (potentially large) tensors before clearing the cache.
        inputs = None
        streamer = None
        clear_cuda_cache()


##############################################################################
# X-RAY security scanning examples
##############################################################################
examples = [
    [
        {
            "text": "Analyze this X-RAY image for any prohibited items or security threats. Identify all weapons, explosives, batteries, sharp objects, and liquids over 100ml.",
            "files": ["assets/additional-examples/beam1.png"],
        }
    ],
    [
        {
            "text": "Perform a comprehensive security scan on this luggage X-RAY. List all detected threats with severity levels (HIGH/MEDIUM/LOW).",
            "files": ["assets/additional-examples/beam2.png"],
        }
    ],
    [
        {
            "text": "Compare these two X-RAY scans. Which one contains more security threats? Provide detailed analysis of prohibited items in each.",
            "files": ["assets/additional-examples/beam1.png", "assets/additional-examples/beam2.png"],
        }
    ],
    [
        {
            "text": "Is this bag safe for air travel? Check for: guns, knives, bombs, batteries, scissors, springs, and containers over 100ml.",
            "files": ["assets/additional-examples/beam1.png"],
        }
    ],
    [
        {
            "text": "Security checkpoint analysis: Identify any EOD (Explosive Ordnance Disposal) related items or components that could be assembled into weapons.",
            "files": ["assets/additional-examples/beam2.png"],
        }
    ],
    [
        {
            "text": "Quick scan for immediate threats: Focus on firearms, bladed weapons, and explosive materials only.",
            "files": ["assets/additional-examples/beam1.png"],
        }
    ],
    [
        {
            "text": "Detailed inspection required: Check for concealed weapons, electronic devices with large batteries, and any suspicious dense materials.",
            "files": ["assets/additional-examples/beam2.png"],
        }
    ],
    [
        {
            "text": "Training mode: Identify and explain why each detected item is considered a security threat according to TSA/aviation security standards.",
            "files": ["assets/additional-examples/beam1.png"],
        }
    ],
    [
        {
            "text": "Border security check: Scan for contraband, weapons, and any items that violate international travel regulations.",
            "files": ["assets/additional-examples/beam2.png"],
        }
    ],
    [
        {
            "text": "Emergency protocol: Priority scan for immediate threats - explosives, firearms, and large bladed weapons only. Report findings urgently.",
            "files": ["assets/additional-examples/beam1.png"],
        }
    ],
]


##############################################################################
# Gradio UI (Blocks) setup
##############################################################################
css = """
.gradio-container {
  background: white;
  padding: 30px 40px;
  margin: 20px auto;
  width: 100% !important;
  max-width: none !important;
}
.fillable {
  width: 100% !important;
  max-width: 100% !important;
}
body {
  background: white;
  margin: 0;
  padding: 0;
  font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
  color: #333;
}
button, .btn {
  background: transparent !important;
  border: 1px solid #ddd;
  color: #333;
  padding: 12px 24px;
  text-transform: uppercase;
  font-weight: bold;
  letter-spacing: 1px;
  cursor: pointer;
}
button:hover, .btn:hover {
  background: rgba(0, 0, 0, 0.05) !important;
}
h1, h2, h3 {
  color: #333;
}
.multimodal-textbox, textarea, input {
  background: rgba(255, 255, 255, 0.5) !important;
  border: 1px solid #ddd;
  color: #333;
}
.chatbox, .chatbot, .message {
  background: transparent !important;
}
#examples_container, .examples-container {
  margin: auto;
  width: 90%;
  background: transparent !important;
}
"""

title_html = """

Gemma-3-R1984-4B-BEAM

"""

with gr.Blocks(css=css, title="Gemma-3-R1984-4B-BEAM - X-RAY Security Scanner") as demo:
    gr.Markdown(title_html)

    # Display the web search option (while the system prompt and token slider remain hidden)
    web_search_checkbox = gr.Checkbox(
        label="Deep Research",
        value=False,
    )

    # X-RAY security scanning system prompt
    system_prompt_box = gr.Textbox(
        lines=3,
        value="""You are an advanced X-RAY security scanning AI specialized in threat detection and aviation security.

Your primary mission is to identify ALL potential security threats in X-RAY images with extreme precision.

DETECTION PRIORITIES:
1. WEAPONS: Firearms (guns, pistols, rifles), knives, blades, sharp objects, martial arts weapons
2. EXPLOSIVES: Bombs, detonators, explosive materials, suspicious electronics, wires with batteries
3. PROHIBITED ITEMS: Scissors, large batteries, springs (potential weapon components), tools
4. LIQUIDS: Any container over 100ml (potential chemical threats)
5. EOD COMPONENTS: Any items that could be assembled into explosive devices

ANALYSIS PROTOCOL:
- Scan systematically from top-left to bottom-right
- Report location of threats using grid references (e.g., "upper-left quadrant")
- Classify threat severity: HIGH (immediate danger), MEDIUM (prohibited), LOW (requires inspection)
- Use professional security terminology
- Provide recommended actions for each threat

CRITICAL: Never miss a potential threat. When in doubt, flag for manual inspection.""",
        visible=False,  # hidden from view
    )

    max_tokens_slider = gr.Slider(
        label="Max New Tokens",
        minimum=100,
        maximum=8000,
        step=50,
        value=1000,
        visible=False,  # hidden from view
    )

    web_search_text = gr.Textbox(
        lines=1,
        label="Web Search Query",
        placeholder="",
        visible=False,  # hidden from view
    )

    # Configure the chat interface; additional_inputs map positionally onto
    # run(system_prompt, max_new_tokens, use_web_search, web_search_query).
    chat = gr.ChatInterface(
        fn=run,
        type="messages",
        chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
        textbox=gr.MultimodalTextbox(
            file_types=[
                ".webp", ".png", ".jpg", ".jpeg", ".gif",
                ".mp4", ".csv", ".txt", ".pdf",
            ],
            file_count="multiple",
            autofocus=True,
        ),
        multimodal=True,
        additional_inputs=[
            system_prompt_box,
            max_tokens_slider,
            web_search_checkbox,
            web_search_text,
        ],
        stop_btn=False,
        title='https://discord.gg/openfreeai',
        examples=examples,
        run_examples_on_click=False,
        cache_examples=False,
        css_paths=None,
        delete_cache=(1800, 1800),
    )

    # Example section - examples are already set in ChatInterface; this row is display-only
    with gr.Row(elem_id="examples_row"):
        with gr.Column(scale=12, elem_id="examples_container"):
            pass


if __name__ == "__main__":
    # Run locally
    demo.launch()