Syllabus-Formatter / model /analyzer.py
Kuberwastaken's picture
Increased Model Efficiency
2b6c9d9
raw
history blame
5.78 kB
import os
import asyncio
import torch
from datetime import datetime
import gradio as gr
from typing import Dict, List, Union, Optional
import logging
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Module-level logger used by the analyzer class and the entry point.
logger = logging.getLogger(__name__)

# Route INFO-and-above records through the root logger's default handler.
logging.basicConfig(level=logging.INFO)
class ContentAnalyzer:
    """Flag sensitive-content categories in free text with a causal language model.

    The tokenizer and the 4-bit quantized model are loaded lazily: either by an
    explicit call to ``load_model`` or on the first ``analyze_script`` call that
    actually has text to analyze.
    """

    # Hugging Face Hub identifier used for both the tokenizer and the model.
    MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    # Category labels listed in the prompt and searched for in the model reply.
    CATEGORIES = (
        "Violence", "Death", "Substance Use", "Gore",
        "Vomit", "Sexual Content", "Sexual Abuse",
        "Self-Harm", "Gun Use", "Animal Cruelty",
        "Mental Health Issues",
    )

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        logger.info(f"Initialized analyzer with device: {self.device}")

    async def load_model(self, progress=None) -> None:
        """Load the tokenizer and the 4-bit quantized model.

        Args:
            progress: Optional callable invoked as ``progress(fraction, message)``
                to report loading progress.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        try:
            if progress:
                progress(0.1, "Loading tokenizer...")
            # Quantization configuration
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.MODEL_ID,
                use_fast=True
            )
            if progress:
                progress(0.3, "Loading quantized model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.MODEL_ID,
                quantization_config=quantization_config,
                device_map="auto"
            )
            if progress:
                progress(0.5, "Model loaded successfully")
        except Exception as e:
            logger.error(f"Model loading error: {str(e)}")
            traceback.print_exc()
            raise

    def _semantic_chunk_text(self, text: str, max_chunk_size: int = 4096) -> List[str]:
        """Split ``text`` into sentence-aligned chunks of bounded length.

        Sentences are delimited by ``'.'`` and packed greedily into chunks of at
        most ``max_chunk_size`` characters. A single sentence longer than the
        limit is cut into fixed-size pieces so no chunk exceeds the limit.

        Args:
            text: The text to split.
            max_chunk_size: Maximum number of characters per chunk; must be positive.

        Returns:
            The non-empty, whitespace-stripped chunks in original order. Empty or
            whitespace-only input yields an empty list.

        Raises:
            ValueError: If ``max_chunk_size`` is not positive.
        """
        if max_chunk_size <= 0:
            raise ValueError("max_chunk_size must be positive")

        pieces = text.split('.')
        # Every piece except the last was followed by a period in the input, so
        # only those get their period restored; a trailing fragment keeps its
        # original (period-less) form instead of gaining a spurious '.'.
        sentences = [piece + '.' for piece in pieces[:-1]]
        if pieces[-1].strip():
            sentences.append(pieces[-1])

        chunks: List[str] = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= max_chunk_size:
                current_chunk += sentence
                continue
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            # Hard-split a sentence that cannot fit in a chunk on its own.
            while len(sentence) > max_chunk_size:
                head, sentence = sentence[:max_chunk_size], sentence[max_chunk_size:]
                if head.strip():
                    chunks.append(head.strip())
            current_chunk = sentence
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        return chunks

    @classmethod
    def _match_categories(cls, response: str) -> List[str]:
        """Return the categories whose name occurs (case-insensitively) in ``response``."""
        haystack = response.upper()
        return [cat for cat in cls.CATEGORIES if cat.upper() in haystack]

    async def analyze_chunk(
        self,
        chunk: str,
        progress: "Optional[gr.Progress]" = None
    ) -> List[str]:
        """Ask the model which sensitive categories are present in one chunk.

        Args:
            chunk: The text to analyze.
            progress: Accepted for interface compatibility; not used here.

        Returns:
            The detected category names, in ``CATEGORIES`` order. On any error
            the failure is logged and an empty list is returned (best effort).
        """
        prompt = (
            "Analyze this text for sensitive content.\n"
            f"Categories: {', '.join(self.CATEGORIES)}\n"
            "Identify ALL present categories.\n"
            "Be precise and direct.\n"
            f"Chunk: {chunk}\n"
            "Output Format: Comma-separated category names if present."
        )
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            prompt_length = inputs["input_ids"].shape[1]
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # The generated sequence starts with the prompt tokens. The prompt
            # itself lists every category name, so decoding the whole sequence
            # would make every category match; decode only the new tokens.
            generated_tokens = outputs[0][prompt_length:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            return self._match_categories(response)
        except Exception as e:
            logger.error(f"Chunk analysis error: {str(e)}", exc_info=True)
            return []

    async def analyze_script(self, script: str, progress: "Optional[gr.Progress]" = None) -> List[str]:
        """Analyze a whole script and return the union of detected categories.

        Args:
            script: The full text to analyze.
            progress: Optional progress callback forwarded to ``load_model``.

        Returns:
            The detected category names in ``CATEGORIES`` order, or ``["None"]``
            when nothing was detected (including empty input).
        """
        chunks = self._semantic_chunk_text(script)
        if not chunks:
            # Nothing to analyze: skip the expensive model load entirely.
            return ["None"]

        if self.model is None or self.tokenizer is None:
            await self.load_model(progress)

        # Chunks are processed one after another: generation never yields to the
        # event loop, so scheduling the coroutines together gains no concurrency.
        found = set()
        for chunk in chunks:
            found.update(await self.analyze_chunk(chunk))

        # Report in the canonical category order so the output is deterministic.
        identified_triggers = [cat for cat in self.CATEGORIES if cat in found]
        return identified_triggers or ["None"]
# Process-wide analyzer so the tokenizer and model weights are loaded once
# rather than once per request.
_shared_analyzer: Optional["ContentAnalyzer"] = None


def _get_shared_analyzer() -> "ContentAnalyzer":
    """Return the process-wide ``ContentAnalyzer``, creating it on first use."""
    global _shared_analyzer
    if _shared_analyzer is None:
        _shared_analyzer = ContentAnalyzer()
    return _shared_analyzer


async def analyze_content(
    script: str,
    progress: Optional[gr.Progress] = None
) -> Dict[str, Union[List[str], str]]:
    """Analyze ``script`` for sensitive content and build a report dictionary.

    Args:
        script: The text to analyze.
        progress: Optional progress callback forwarded to the analyzer.

    Returns:
        A dict with the keys ``detected_triggers``, ``confidence``, ``model``
        and ``analysis_timestamp``. If the analysis raises, the exception is
        logged and an error report (with an extra ``error`` key holding the
        exception text) is returned instead of propagating the exception.
    """
    model_name = "DeepSeek-R1-Distill-Qwen-1.5B"
    # Reuse one analyzer across calls: a fresh instance per request would
    # reload the tokenizer and model every time.
    analyzer = _get_shared_analyzer()
    try:
        triggers = await analyzer.analyze_script(script, progress)
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Analysis error: {str(e)}")
        return {
            "detected_triggers": ["Error occurred during analysis"],
            "confidence": "Error",
            "model": model_name,
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "error": str(e)
        }

    # ["None"] is the analyzer's sentinel for "nothing detected".
    if triggers != ["None"]:
        confidence = "High - Content detected"
    else:
        confidence = "High - No concerning content detected"
    return {
        "detected_triggers": triggers,
        "confidence": confidence,
        "model": model_name,
        "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
if __name__ == "__main__":
    # Build the UI pieces first, then wire them to the analysis entry point.
    text_input = gr.Textbox(lines=8, label="Input Text")
    json_output = gr.JSON()

    iface = gr.Interface(
        fn=analyze_content,
        inputs=text_input,
        outputs=json_output,
        title="Content Sensitivity Analysis",
        description="Analyze text content for sensitive topics using DeepSeek R1",
    )

    # Start serving the interface when the file is run as a script.
    iface.launch()