import os
import asyncio
import logging
import traceback
from datetime import datetime
from typing import Dict, List, Union, Optional

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ContentAnalyzer:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        logger.info(f"Initialized analyzer with device: {self.device}")

    async def load_model(self, progress=None) -> None:
        """Load the tokenizer and the 4-bit quantized model."""
        try:
            if progress:
                progress(0.1, "Loading tokenizer...")

            # 4-bit NF4 quantization configuration
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4"
            )

            self.tokenizer = AutoTokenizer.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                use_fast=True
            )

            if progress:
                progress(0.3, "Loading quantized model...")

            self.model = AutoModelForCausalLM.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                quantization_config=quantization_config,
                device_map="auto"
            )

            if progress:
                progress(0.5, "Model loaded successfully")
        except Exception as e:
            logger.error(f"Model loading error: {str(e)}")
            traceback.print_exc()
            raise

    def _semantic_chunk_text(self, text: str, max_chunk_size: int = 4096) -> List[str]:
        """Split text on sentence boundaries into chunks of at most max_chunk_size characters."""
        chunks = []
        current_chunk = ""
        for sentence in text.split('.'):
            if len(current_chunk) + len(sentence) < max_chunk_size:
                current_chunk += sentence + '.'
            else:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + '.'
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    async def analyze_chunk(
        self,
        chunk: str,
        progress: Optional[gr.Progress] = None
    ) -> List[str]:
        """Single-pass analysis of one chunk against all categories."""
        categories = [
            "Violence", "Death", "Substance Use", "Gore",
            "Vomit", "Sexual Content", "Sexual Abuse",
            "Self-Harm", "Gun Use", "Animal Cruelty",
            "Mental Health Issues"
        ]

        prompt = f"""Analyze this text for sensitive content.
Categories: {', '.join(categories)}
Identify ALL present categories.
Be precise and direct.
Chunk: {chunk}
Output Format: Comma-separated category names if present."""

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Decode only the newly generated tokens: the prompt itself lists every
            # category, so scanning the full sequence would flag all of them.
            generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

            # Extract detected categories
            detected = [
                cat for cat in categories
                if cat.upper() in response.upper()
            ]
            return detected
        except Exception as e:
            logger.error(f"Chunk analysis error: {str(e)}")
            return []

    async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
        if not self.model or not self.tokenizer:
            await self.load_model(progress)

        chunks = self._semantic_chunk_text(script)

        # Analyze each chunk; generate() is blocking, so the chunks are
        # processed sequentially even though they are gathered as coroutines.
        tasks = [self.analyze_chunk(chunk) for chunk in chunks]
        chunk_results = await asyncio.gather(*tasks)

        # Flatten and deduplicate results
        identified_triggers = set(
            trigger
            for chunk_triggers in chunk_results
            for trigger in chunk_triggers
        )
        return list(identified_triggers) or ["None"]


async def analyze_content(
    script: str,
    progress: Optional[gr.Progress] = None
) -> Dict[str, Union[List[str], str]]:
    analyzer = ContentAnalyzer()
    try:
        triggers = await analyzer.analyze_script(script, progress)
        return {
            "detected_triggers": triggers,
            "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        return {
            "detected_triggers": ["Error occurred during analysis"],
            "confidence": "Error",
            "model": "DeepSeek-R1-Distill-Qwen-1.5B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "error": str(e)
        }


if __name__ == "__main__":
    iface = gr.Interface(
        fn=analyze_content,
        inputs=gr.Textbox(lines=8, label="Input Text"),
        outputs=gr.JSON(),
        title="Content Sensitivity Analysis",
        description="Analyze text content for sensitive topics using DeepSeek R1"
    )
    iface.launch()
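
# The analyzer can also be exercised without the Gradio UI. A minimal sketch,
# assuming the model weights can be downloaded and that no event loop is
# already running (call this instead of iface.launch()):
#
#     import asyncio
#     result = asyncio.run(analyze_content("Sample script text to screen..."))
#     print(result["detected_triggers"])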