import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datetime import datetime
import gradio as gr
from typing import Dict, List, Union, Optional
import logging
import traceback

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ContentAnalyzer:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        logger.info(f"Initialized analyzer with device: {self.device}")

    async def load_model(self, progress=None) -> None:
        """Load the model and tokenizer with progress updates and detailed logging."""
        try:
            print("\n=== Starting Model Loading ===")
            print(f"Time: {datetime.now()}")

            if progress:
                progress(0.1, "Loading tokenizer...")
            print("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                use_fast=True
            )

            if progress:
                progress(0.3, "Loading model...")
            print(f"Loading model on {self.device}...")
            self.model = AutoModelForCausalLM.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto"
            )

            if progress:
                progress(0.5, "Model loaded successfully")
            print("Model and tokenizer loaded successfully")
            logger.info(f"Model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            print(f"\nERROR DURING MODEL LOADING: {str(e)}")
            print("Stack trace:")
            traceback.print_exc()
            raise

    def _chunk_text(self, text: str, chunk_size: int = 2048, overlap: int = 256) -> List[str]:
        """Split text into overlapping character-based chunks for processing."""
        chunks = []
        for i in range(0, len(text), chunk_size - overlap):
            chunks.append(text[i:i + chunk_size])
        # Chunking operates on characters, not tokens.
        print(f"Split text into {len(chunks)} chunks with {overlap} character overlap")
        return chunks

    async def analyze_chunk(
        self,
        chunk: str,
        progress: Optional[gr.Progress] = None,
        current_progress: float = 0,
        progress_step: float = 0
    ) -> List[str]:
        """Analyze a single chunk of text for triggers with detailed logging."""
        print("\n--- Processing Chunk ---")
        print(f"Chunk text (preview): {chunk[:50]}...")

        # Comprehensive trigger categories
        categories = [
            "Violence", "Death", "Substance Use", "Gore", "Vomit",
            "Sexual Content", "Sexual Abuse", "Self-Harm", "Gun Use",
            "Animal Cruelty", "Mental Health Issues"
        ]

        # Comprehensive prompt for single-pass analysis
        prompt = f"""Comprehensive Content Sensitivity Analysis

Carefully analyze the following text for sensitive content categories:
{', '.join(categories)}

Detailed Requirements:
1. Thoroughly examine entire text chunk
2. Identify presence of ANY of these categories
3. Provide clear, objective assessment
4. Minimal subjective interpretation

TEXT CHUNK:
{chunk}

RESPONSE FORMAT:
- List categories DEFINITIVELY present
- Brief objective justification for each
- Strict YES/NO categorization"""

        try:
            print("Sending prompt to model...")
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
            # With device_map="auto", model.device points at the device of the
            # first shard, which is where the inputs need to land.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                print("Generating response...")
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.2,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens. Decoding the full
            # sequence would echo the prompt, whose category list would
            # trivially match every keyword check below.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            print("Full Model Response:", response_text)

            # Parse detected triggers via case-insensitive keyword matching
            detected_triggers = []
            for category in categories:
                if category.upper() in response_text.upper():
                    detected_triggers.append(category)

            print(f"Detected triggers in chunk: {detected_triggers}")

            if progress:
                progress(min(current_progress + progress_step, 0.9), "Analyzing chunk...")

            return detected_triggers
        except Exception as e:
            logger.error(f"Error analyzing chunk: {str(e)}")
            print(f"Error during chunk analysis: {str(e)}")
            traceback.print_exc()
            return []

    async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
        """Analyze the entire script for triggers with progress updates."""
        print("\n=== Starting Script Analysis ===")
        print(f"Time: {datetime.now()}")

        if not self.model or not self.tokenizer:
            await self.load_model(progress)

        chunks = self._chunk_text(script)
        identified_triggers = set()
        # Guard against an empty script producing zero chunks.
        progress_step = 0.4 / max(len(chunks), 1)
        current_progress = 0.5  # Starting after model loading

        for chunk in chunks:
            chunk_triggers = await self.analyze_chunk(
                chunk,
                progress,
                current_progress,
                progress_step
            )
            identified_triggers.update(chunk_triggers)
            # Floats are passed by value, so accumulate progress here rather
            # than inside analyze_chunk, where the increment would be lost.
            current_progress += progress_step

        if progress:
            progress(0.95, "Finalizing results...")

        final_triggers = list(identified_triggers)
        print("\n=== Analysis Complete ===")
        print("Final Results:", final_triggers)

        return final_triggers if final_triggers else ["None"]


async def analyze_content(
    script: str,
    progress: Optional[gr.Progress] = None
) -> Dict[str, Union[List[str], str]]:
    """Main analysis function for the Gradio interface."""
    print("\n=== Starting Content Analysis ===")
    print(f"Time: {datetime.now()}")

    analyzer = ContentAnalyzer()

    try:
        triggers = await analyzer.analyze_script(script, progress)
        if progress:
            progress(1.0, "Analysis complete!")

        result = {
            "detected_triggers": triggers,
            "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
        print("\nFinal Result Dictionary:", result)
        return result
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        print(f"\nERROR OCCURRED: {str(e)}")
        print("Stack trace:")
        traceback.print_exc()
        return {
            "detected_triggers": ["Error occurred during analysis"],
            "confidence": "Error",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "error": str(e)
        }


if __name__ == "__main__":
    # Gradio interface
    iface = gr.Interface(
        fn=analyze_content,
        inputs=gr.Textbox(lines=8, label="Input Text"),
        outputs=gr.JSON(),
        title="Content Sensitivity Analysis",
        description="Analyze text content for sensitive topics using DeepSeek R1"
    )
    iface.launch()
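
# Minimal programmatic usage sketch (the sample sentence and the triggers shown
# are illustrative assumptions, not guaranteed model output). Because
# analyze_content is async, it can be driven without the Gradio UI via asyncio:
#
#   import asyncio
#   result = asyncio.run(analyze_content("He raised the gun and fired twice."))
#   print(result["detected_triggers"])  # e.g. ["Violence", "Gun Use"]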