Kuberwastaken committed on
Commit cba901f · 1 Parent(s): cad764a

Potentially more efficient model

Files changed (1)
  1. model/analyzer.py +66 -101
model/analyzer.py CHANGED
@@ -1,13 +1,13 @@
 import os
-import asyncio
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from datetime import datetime
 import gradio as gr
 from typing import Dict, List, Union, Optional
 import logging
-import traceback
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+import re
 
+# Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -16,143 +16,108 @@ class ContentAnalyzer:
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = None
         self.tokenizer = None
+        self.categories = [
+            "Violence", "Death", "Substance Use", "Gore",
+            "Vomit", "Sexual Content", "Sexual Abuse",
+            "Self-Harm", "Gun Use", "Animal Cruelty",
+            "Mental Health Issues"
+        ]
+        self.pattern = re.compile(r'\b(' + '|'.join(self.categories) + r')\b', re.IGNORECASE)
         logger.info(f"Initialized analyzer with device: {self.device}")
+        self._load_model()  # Load model during initialization
 
-    async def load_model(self, progress=None) -> None:
-        """Load quantized model with optimized configuration."""
+    def _load_model(self) -> None:
+        """Load model and tokenizer synchronously during initialization"""
         try:
-            if progress:
-                progress(0.1, "Loading tokenizer...")
-
-            # Quantization configuration
-            quantization_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=torch.float16,
-                bnb_4bit_quant_type="nf4"
-            )
-
+            logger.info("Loading model components...")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
-                use_fast=True
+                use_fast=True,
+                truncation_side="left"
             )
-
-            if progress:
-                progress(0.3, "Loading quantized model...")
-
             self.model = AutoModelForCausalLM.from_pretrained(
                 "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
-                quantization_config=quantization_config,
-                device_map="auto"
-            )
-
-            if progress:
-                progress(0.5, "Model loaded successfully")
-
+                torch_dtype=torch.float32,
+                low_cpu_mem_usage=True
+            ).to(self.device).eval()
+            logger.info("Model loaded successfully")
         except Exception as e:
-            logger.error(f"Model loading error: {str(e)}")
-            traceback.print_exc()
+            logger.error(f"Model loading failed: {str(e)}")
             raise
 
-    def _semantic_chunk_text(self, text: str, max_chunk_size: int = 4096) -> List[str]:
-        """Semantic chunking with dynamic sizing."""
+    def _chunk_text(self, text: str, chunk_size: int = 1024) -> List[str]:
+        """Optimized chunking using paragraph boundaries"""
+        paragraphs = text.split('\n\n')
        chunks = []
         current_chunk = ""
-        for sentence in text.split('.'):
-            if len(current_chunk) + len(sentence) < max_chunk_size:
-                current_chunk += sentence + '.'
+
+        for para in paragraphs:
+            if len(current_chunk) + len(para) < chunk_size:
+                current_chunk += para + "\n\n"
             else:
-                chunks.append(current_chunk.strip())
-                current_chunk = sentence + '.'
+                if current_chunk:
+                    chunks.append(current_chunk.strip())
+                current_chunk = para + "\n\n"
 
         if current_chunk:
             chunks.append(current_chunk.strip())
-
+
+        logger.info(f"Split text into {len(chunks)} chunks")
         return chunks
 
-    async def analyze_chunk(
-        self,
-        chunk: str,
-        progress: Optional[gr.Progress] = None
-    ) -> List[str]:
-        """Optimized single-pass chunk analysis."""
-        categories = [
-            "Violence", "Death", "Substance Use", "Gore",
-            "Vomit", "Sexual Content", "Sexual Abuse",
-            "Self-Harm", "Gun Use", "Animal Cruelty",
-            "Mental Health Issues"
-        ]
-
-        prompt = f"""Analyze this text for sensitive content.
-Categories: {', '.join(categories)}
-Identify ALL present categories.
-Be precise and direct.
-Chunk: {chunk}
-Output Format: Comma-separated category names if present."""
-
-        try:
-            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-            outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=128,
-                do_sample=True,
-                temperature=0.2,
-                top_p=0.9,
-                pad_token_id=self.tokenizer.eos_token_id
-            )
-
-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-            # Extract detected categories
-            detected = [
-                cat for cat in categories
-                if cat.upper() in response.upper()
-            ]
-
-            return detected
-
-        except Exception as e:
-            logger.error(f"Chunk analysis error: {str(e)}")
-            return []
+    async def _analyze_chunk(self, chunk: str) -> List[str]:
+        """Optimized chunk analysis with structured prompt"""
+        prompt = f"""You are a highly specialized content analysis AI, Analyze this text for sensitive content from: {', '.join(self.categories)}.
+Respond with categories in format: [CATEGORIES]:
+
+Text: {chunk[:2000]}
+[CATEGORIES]: """
+
+        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=50,
+                do_sample=False,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+
+        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return [m.capitalize() for m in self.pattern.findall(response)]
 
     async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
-        if not self.model or not self.tokenizer:
-            await self.load_model(progress)
-
-        chunks = self._semantic_chunk_text(script)
-
-        # Concurrent chunk processing
-        tasks = [self.analyze_chunk(chunk) for chunk in chunks]
-        chunk_results = await asyncio.gather(*tasks)
-
-        # Flatten and deduplicate results
-        identified_triggers = set(
-            trigger
-            for chunk_triggers in chunk_results
-            for trigger in chunk_triggers
-        )
-
-        return list(identified_triggers) or ["None"]
+        """Main analysis method with progress support"""
+        identified_triggers = set()
+        chunks = self._chunk_text(script)
+
+        for idx, chunk in enumerate(chunks):
+            if progress:
+                progress((idx/len(chunks), f"Analyzing chunk {idx+1}/{len(chunks)}"))
+
+            triggers = await self._analyze_chunk(chunk)
+            identified_triggers.update(triggers)
+
+        if progress:
+            progress((1.0, "Analysis complete"))
+
+        return sorted(identified_triggers) if identified_triggers else ["None"]
 
 async def analyze_content(
     script: str,
     progress: Optional[gr.Progress] = None
 ) -> Dict[str, Union[List[str], str]]:
-    analyzer = ContentAnalyzer()
-
+    """Main analysis function for Gradio interface"""
     try:
+        analyzer = ContentAnalyzer()
         triggers = await analyzer.analyze_script(script, progress)
-
-        result = {
+
+        return {
             "detected_triggers": triggers,
             "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
             "model": "DeepSeek-R1-Distill-Qwen-1.5B",
             "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         }
-
-        return result
-
     except Exception as e:
         logger.error(f"Analysis error: {str(e)}")
         return {
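
Two quick sketches may help when reviewing this change. First, the reworked `analyze_content` coroutine can be smoke-tested outside Gradio; this assumes the file is importable as `model.analyzer`, and the sample script below is invented:

    import asyncio
    import json

    from model.analyzer import analyze_content

    # ContentAnalyzer() now loads the model inside __init__, so the first
    # call pays the full download/load cost before any chunk is analyzed.
    sample = "He raised the gun.\n\nBlood pooled across the floor."
    result = asyncio.run(analyze_content(sample))
    print(json.dumps(result, indent=2))

Second, the new regex post-processing normalizes matches with `str.capitalize()`, which lowercases everything after the first character. A standalone run of that step (category list shortened from the diff, model `response` invented) shows what comes back for multi-word labels:

    import re

    categories = ["Violence", "Sexual Content", "Gun Use"]
    pattern = re.compile(r'\b(' + '|'.join(categories) + r')\b', re.IGNORECASE)

    response = "[CATEGORIES]: violence, gun use"
    print([m.capitalize() for m in pattern.findall(response)])
    # prints ['Violence', 'Gun use'] -- inner capitals are not restored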