Spaces:
Sleeping
Sleeping
import logging
import os
import re
import traceback
from datetime import datetime
from typing import Dict, List, Optional, Union

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Route INFO-and-above records through the root handler, and give this
# module its own named logger for the messages it emits.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class ContentAnalyzer:
    """Flag sensitive-content categories in text using a causal language model.

    The tokenizer and model are loaded lazily by :meth:`analyze_script` on first
    use, so constructing an analyzer only selects a device.
    """

    # Hugging Face Hub id used for both the tokenizer and the model.
    MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    # Categories the model is asked about; detected triggers are reported in this order.
    CATEGORIES = (
        "Violence", "Death", "Substance Use", "Gore",
        "Vomit", "Sexual Content", "Sexual Abuse",
        "Self-Harm", "Gun Use", "Animal Cruelty",
        "Mental Health Issues",
    )

    # The model's answer is judged piecewise: one line / sentence / clause at a time.
    _SEGMENT_RE = re.compile(r"[\n.;]")
    # Upper-cased words that mark a segment as a negative verdict (e.g. "GORE: NO").
    _NEGATION_RE = re.compile(r"\b(?:NO|NOT|NONE|ABSENT)\b")
    # An explicit "YES" in a segment overrides any negation word in it.
    _AFFIRMATION_RE = re.compile(r"\bYES\b")

    def __init__(self):
        """Pick CUDA when available, otherwise CPU; defer model loading."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        logger.info(f"Initialized analyzer with device: {self.device}")

    async def load_model(self, progress: Optional["gr.Progress"] = None) -> None:
        """Load the tokenizer and model from ``MODEL_ID``.

        Args:
            progress: Optional callable ``progress(fraction, description)`` used
                to report loading milestones (0.1, 0.3, 0.5).

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        try:
            print("\n=== Starting Model Loading ===")
            print(f"Time: {datetime.now()}")
            if progress:
                progress(0.1, "Loading tokenizer...")
            print("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.MODEL_ID,
                use_fast=True,
            )
            if progress:
                progress(0.3, "Loading model...")
            print(f"Loading model on {self.device}...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.MODEL_ID,
                # Half precision only on GPU; CPU kernels generally expect float32.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto",
            )
            if progress:
                progress(0.5, "Model loaded successfully")
            print("Model and tokenizer loaded successfully")
            logger.info(f"Model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            print(f"\nERROR DURING MODEL LOADING: {str(e)}")
            print("Stack trace:")
            traceback.print_exc()
            raise

    def _chunk_text(self, text: str, chunk_size: int = 2048, overlap: int = 256) -> List[str]:
        """Split *text* into character windows of ``chunk_size`` overlapping by ``overlap``.

        Empty text yields ``[]``. Splitting stops once a window reaches the end
        of the text, so no trailing chunk is merely a suffix of the previous one.

        Raises:
            ValueError: If ``chunk_size <= 0`` or ``overlap`` is outside
                ``[0, chunk_size)``; such values would stall or skip text.
        """
        if chunk_size <= 0 or not 0 <= overlap < chunk_size:
            raise ValueError(
                f"need chunk_size > 0 and 0 <= overlap < chunk_size, "
                f"got chunk_size={chunk_size}, overlap={overlap}"
            )
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(text), step):
            chunks.append(text[start:start + chunk_size])
            if start + chunk_size >= len(text):
                break  # this window already covers the tail of the text
        print(f"Split text into {len(chunks)} chunks with {overlap} character overlap")
        return chunks

    def _extract_triggers(self, response_text: str) -> List[str]:
        """Return the categories the model's answer marks as present.

        Only text after the last ``</think>`` tag is judged: R1-style models may
        emit their deliberation (which names every category) before it. The
        answer is split into lines / sentences / clauses; a category counts as
        detected when some segment names it and either says "YES" or contains
        no negation word. Results follow ``CATEGORIES`` order.
        """
        answer = response_text.rsplit("</think>", 1)[-1].upper()
        segments = self._SEGMENT_RE.split(answer)
        detected = []
        for category in self.CATEGORIES:
            needle = category.upper()
            for segment in segments:
                if needle not in segment:
                    continue
                if self._AFFIRMATION_RE.search(segment) or not self._NEGATION_RE.search(segment):
                    detected.append(category)
                    break
        return detected

    async def analyze_chunk(
        self,
        chunk: str,
        progress: Optional["gr.Progress"] = None,
        current_progress: float = 0,
        progress_step: float = 0,
    ) -> List[str]:
        """Ask the model which categories appear in *chunk*.

        Args:
            chunk: Text to analyze.
            progress: Optional progress callable; advanced once after the chunk.
            current_progress: Progress fraction reached before this chunk.
            progress_step: Fraction to add once this chunk is done (capped at 0.9).

        Returns:
            Detected categories in ``CATEGORIES`` order. On any error the error
            is logged and ``[]`` is returned, so a failed chunk contributes no
            triggers to the overall result.
        """
        print(f"\n--- Processing Chunk ---")
        print(f"Chunk text (preview): {chunk[:50]}...")
        categories = self.CATEGORIES
        prompt = "\n".join([
            "Comprehensive Content Sensitivity Analysis",
            "Carefully analyze the following text for sensitive content categories:",
            ", ".join(categories),
            "Detailed Requirements:",
            "1. Thoroughly examine entire text chunk",
            "2. Identify presence of ANY of these categories",
            "3. Provide clear, objective assessment",
            "4. Minimal subjective interpretation",
            "TEXT CHUNK:",
            chunk,
            "RESPONSE FORMAT:",
            "- List categories DEFINITIVELY present",
            "- Brief objective justification for each",
            "- Strict YES/NO categorization",
        ])
        try:
            print("Sending prompt to model...")
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
            # With device_map="auto" the model decides its placement; follow it.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            prompt_length = inputs["input_ids"].shape[-1]
            with torch.no_grad():
                print("Generating response...")
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.2,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            # generate() returns prompt + continuation. Decode only the continuation,
            # otherwise the category list inside the prompt matches every category.
            response_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            print("Full Model Response:", response_text)
            detected_triggers = self._extract_triggers(response_text)
            print(f"Detected triggers in chunk: {detected_triggers}")
            if progress:
                current_progress += progress_step
                progress(min(current_progress, 0.9), "Analyzing chunk...")
            return detected_triggers
        except Exception as e:
            logger.error(f"Error analyzing chunk: {str(e)}")
            print(f"Error during chunk analysis: {str(e)}")
            traceback.print_exc()
            return []

    async def analyze_script(self, script: str, progress: Optional["gr.Progress"] = None) -> List[str]:
        """Analyze the whole *script* chunk by chunk.

        Args:
            script: Text to analyze. Empty or whitespace-only text is not sent
                to the model.
            progress: Optional progress callable; chunk analysis spans 0.5-0.9,
                then 0.95 is reported while finalizing.

        Returns:
            Union of detected categories across chunks in ``CATEGORIES`` order,
            or ``["None"]`` when nothing was detected.
        """
        print("\n=== Starting Script Analysis ===")
        print(f"Time: {datetime.now()}")
        if not script or not script.strip():
            # Nothing to analyze; also avoids dividing by a zero chunk count below.
            print("Empty script; nothing to analyze")
            return ["None"]
        if self.model is None or self.tokenizer is None:
            await self.load_model(progress)
        chunks = self._chunk_text(script)
        identified_triggers = set()
        progress_step = 0.4 / len(chunks)
        for chunk_idx, chunk in enumerate(chunks):
            # 0.5 is where model loading leaves the progress bar.
            current_progress = 0.5 + chunk_idx * progress_step
            chunk_triggers = await self.analyze_chunk(
                chunk,
                progress,
                current_progress,
                progress_step,
            )
            identified_triggers.update(chunk_triggers)
        if progress:
            progress(0.95, "Finalizing results...")
        final_triggers = [c for c in self.CATEGORIES if c in identified_triggers]
        print("\n=== Analysis Complete ===")
        print("Final Results:", final_triggers)
        return final_triggers if final_triggers else ["None"]
# Analyzer shared by every request so the model is loaded once per process
# instead of once per call. Created lazily by _get_shared_analyzer().
_shared_analyzer: Optional["ContentAnalyzer"] = None


def _get_shared_analyzer() -> "ContentAnalyzer":
    """Return the process-wide ContentAnalyzer, creating it on first use."""
    global _shared_analyzer
    if _shared_analyzer is None:
        _shared_analyzer = ContentAnalyzer()
    return _shared_analyzer


async def analyze_content(
    script: str,
    progress: Optional["gr.Progress"] = None,
) -> Dict[str, Union[List[str], str]]:
    """Gradio handler: analyze *script* and return a JSON-serializable report.

    Args:
        script: Text to analyze.
        progress: Optional progress callable forwarded to the analyzer.

    Returns:
        A dict with "detected_triggers", "confidence", "model" and
        "analysis_timestamp". If analysis raises, "detected_triggers" holds an
        error marker, "confidence" is "Error", and "error" carries the message.
    """
    print("\n=== Starting Content Analysis ===")
    print(f"Time: {datetime.now()}")
    # Reuse one analyzer so its loaded model survives between requests.
    analyzer = _get_shared_analyzer()
    try:
        triggers = await analyzer.analyze_script(script, progress)
        if progress:
            progress(1.0, "Analysis complete!")
        result = {
            "detected_triggers": triggers,
            "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
        print("\nFinal Result Dictionary:", result)
        return result
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        print(f"\nERROR OCCURRED: {str(e)}")
        print("Stack trace:")
        traceback.print_exc()
        return {
            "detected_triggers": ["Error occurred during analysis"],
            "confidence": "Error",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "error": str(e),
        }
if __name__ == "__main__":
    # Gradio injects a live progress tracker only into a parameter whose default
    # is gr.Progress(). analyze_content defaults to None, which would leave every
    # progress update unused, so wrap it for the UI.
    async def analyze_with_progress(script: str, progress=gr.Progress()):
        """Gradio entry point: run analyze_content with Gradio's progress tracker."""
        return await analyze_content(script, progress)

    # Gradio interface
    iface = gr.Interface(
        fn=analyze_with_progress,
        inputs=gr.Textbox(lines=8, label="Input Text"),
        outputs=gr.JSON(),
        title="Content Sensitivity Analysis",
        description="Analyze text content for sensitive topics using DeepSeek R1",
    )
    iface.launch()