"""Content sensitivity analyzer: a Gradio app that scans text for sensitive
topics using the DeepSeek-R1-Distill-Qwen-1.5B model."""

import logging
import traceback
from datetime import datetime
from typing import Dict, List, Optional, Union

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ContentAnalyzer:
    """Analyzes text for sensitive-content categories using a causal LM."""

    def __init__(self):
        # Prefer GPU when available; CPU generation works but is far slower.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        logger.info(f"Initialized analyzer with device: {self.device}")
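
    # The model is loaded lazily on first use (see analyze_script), so
    # constructing an analyzer stays cheap.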
    async def load_model(self, progress=None) -> None:
        """Load the model and tokenizer with progress updates and detailed logging."""
        try:
            print("\n=== Starting Model Loading ===")
            print(f"Time: {datetime.now()}")

            if progress:
                progress(0.1, "Loading tokenizer...")

            print("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                use_fast=True
            )

            if progress:
                progress(0.3, "Loading model...")

            print(f"Loading model on {self.device}...")
            self.model = AutoModelForCausalLM.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                # Half precision halves GPU memory use; CPU requires float32.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto"
            )

            if progress:
                progress(0.5, "Model loaded successfully")

            print("Model and tokenizer loaded successfully")
            logger.info(f"Model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            print(f"\nERROR DURING MODEL LOADING: {e}")
            print("Stack trace:")
            traceback.print_exc()
            raise
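
    # Chunking keeps each prompt comfortably within the model's context window,
    # while the overlap preserves context across chunk boundaries.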
    def _chunk_text(self, text: str, chunk_size: int = 2048, overlap: int = 256) -> List[str]:
        """Split text into overlapping character-based chunks for processing."""
        chunks = []
        # Step by chunk_size - overlap so consecutive chunks share `overlap` characters.
        for i in range(0, len(text), chunk_size - overlap):
            chunks.append(text[i:i + chunk_size])
        print(f"Split text into {len(chunks)} chunks with {overlap}-character overlap")
        return chunks

    async def analyze_chunk(
        self,
        chunk: str,
        progress: Optional[gr.Progress] = None,
        current_progress: float = 0,
        progress_step: float = 0
    ) -> List[str]:
        """Analyze a single chunk of text for triggers with detailed logging."""
        print("\n--- Processing Chunk ---")
        print(f"Chunk text (preview): {chunk[:50]}...")

        categories = [
            "Violence", "Death", "Substance Use", "Gore",
            "Vomit", "Sexual Content", "Sexual Abuse",
            "Self-Harm", "Gun Use", "Animal Cruelty",
            "Mental Health Issues"
        ]

        prompt = f"""Comprehensive Content Sensitivity Analysis

Carefully analyze the following text for sensitive content categories:
{', '.join(categories)}

Detailed Requirements:
1. Thoroughly examine the entire text chunk
2. Identify the presence of ANY of these categories
3. Provide a clear, objective assessment
4. Keep subjective interpretation to a minimum

TEXT CHUNK:
{chunk}

RESPONSE FORMAT:
- List categories DEFINITIVELY present
- Brief objective justification for each
- Strict YES/NO categorization"""

        try:
            print("Sending prompt to model...")
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
            # device_map="auto" decides parameter placement, so send inputs to
            # the model's device rather than assuming self.device.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                print("Generating response...")
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.2,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens. Decoding the full sequence
            # would include the prompt, whose category list would then match
            # every keyword check below.
            prompt_length = inputs["input_ids"].shape[-1]
            response_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            print("Model response:", response_text)
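
            # Simple keyword matching: a category counts as detected if its name
            # appears anywhere in the model's reply. This is a deliberately crude
            # heuristic; it can misfire if the reply echoes a category while
            # answering "NO" for it.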
            detected_triggers = []
            for category in categories:
                if category.upper() in response_text.upper():
                    detected_triggers.append(category)

            print(f"Detected triggers in chunk: {detected_triggers}")

            if progress:
                current_progress += progress_step
                progress(min(current_progress, 0.9), "Analyzing chunk...")

            return detected_triggers
        except Exception as e:
            logger.error(f"Error analyzing chunk: {e}")
            print(f"Error during chunk analysis: {e}")
            traceback.print_exc()
            return []

    async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
        """Analyze the entire script for triggers with progress updates."""
        print("\n=== Starting Script Analysis ===")
        print(f"Time: {datetime.now()}")

        if not self.model or not self.tokenizer:
            await self.load_model(progress)

        chunks = self._chunk_text(script)
        # Guard against empty input, which would otherwise divide by zero below.
        if not chunks:
            return ["None"]

        identified_triggers = set()
        progress_step = 0.4 / len(chunks)
        current_progress = 0.5

        for chunk_idx, chunk in enumerate(chunks, 1):
            print(f"Processing chunk {chunk_idx}/{len(chunks)}")
            chunk_triggers = await self.analyze_chunk(
                chunk,
                progress,
                current_progress,
                progress_step
            )
            identified_triggers.update(chunk_triggers)
            # analyze_chunk receives current_progress by value, so advance it
            # here to keep the progress bar moving across chunks.
            current_progress += progress_step

        if progress:
            progress(0.95, "Finalizing results...")

        final_triggers = list(identified_triggers)
        print("\n=== Analysis Complete ===")
        print("Final Results:", final_triggers)

        return final_triggers if final_triggers else ["None"]


async def analyze_content(
    script: str,
    progress: Optional[gr.Progress] = None
) -> Dict[str, Union[List[str], str]]:
    """Main analysis function for the Gradio interface."""
    print("\n=== Starting Content Analysis ===")
    print(f"Time: {datetime.now()}")

    analyzer = ContentAnalyzer()

    try:
        triggers = await analyzer.analyze_script(script, progress)

        if progress:
            progress(1.0, "Analysis complete!")

        result = {
            "detected_triggers": triggers,
            "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

        print("\nFinal Result Dictionary:", result)
        return result
    except Exception as e:
        logger.error(f"Analysis error: {e}")
        print(f"\nERROR OCCURRED: {e}")
        print("Stack trace:")
        traceback.print_exc()
        return {
            "detected_triggers": ["Error occurred during analysis"],
            "confidence": "Error",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "error": str(e)
        }


if __name__ == "__main__":
    # gr.Interface supports async handlers, so analyze_content can be passed directly.
    iface = gr.Interface(
        fn=analyze_content,
        inputs=gr.Textbox(lines=8, label="Input Text"),
        outputs=gr.JSON(),
        title="Content Sensitivity Analysis",
        description="Analyze text content for sensitive topics using DeepSeek R1"
    )
    iface.launch()
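
# Minimal programmatic usage sketch, bypassing the Gradio UI (assumes the model
# weights can be fetched from the Hugging Face Hub on first run):
#
#   import asyncio
#   result = asyncio.run(analyze_content("Example scene text to analyze..."))
#   print(result["detected_triggers"])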