Cylanoid committed
Commit 19103d4 · verified · 1 Parent(s): 50c6bec

Update document_analyzer.py

Files changed (1)
  1. document_analyzer.py +10 -26
document_analyzer.py CHANGED

@@ -1,5 +1,5 @@
 # document_analyzer.py
-# Enhanced document analysis module for healthcare fraud detection with Llama 4
+# Enhanced document analysis module for healthcare fraud detection with Llama 4 (text-only)
 
 import torch
 import re
@@ -13,9 +13,9 @@ except LookupError:
     nltk.download('punkt')
 
 class HealthcareFraudAnalyzer:
-    def __init__(self, model, processor, device=None):
+    def __init__(self, model, tokenizer, device=None):
         self.model = model
-        self.processor = processor
+        self.tokenizer = tokenizer
         self.device = device if device else "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
         self.model.eval()
@@ -68,32 +68,16 @@ class HealthcareFraudAnalyzer:
         return chunks
 
     def analyze_chunk(self, chunk: str) -> Dict[str, Any]:
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": f"""Analyze the following healthcare document text for evidence of fraud, neglect, abuse, or criminal conduct.
+        prompt = f"""<s>[INST] Analyze the following healthcare document text for evidence of fraud, neglect, abuse, or criminal conduct.
 Focus on: {', '.join(self.fraud_categories)}.
 Provide specific indicators and cite the relevant text.
 
 DOCUMENT TEXT:
 {chunk}
 
-ANALYSIS:"""
-                    }
-                ]
-            }
-        ]
+ANALYSIS: [/INST]"""
 
-        inputs = self.processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt"
-        ).to(self.device)
+        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).to(self.device)
 
         with torch.no_grad():
             output = self.model.generate(
@@ -104,8 +88,8 @@ ANALYSIS:"""
                 repetition_penalty=1.2
             )
 
-        response = self.processor.batch_decode(output[:, inputs["input_ids"].shape[-1]:])[0]
-        analysis = response.strip()
+        response = self.tokenizer.decode(output[0], skip_special_tokens=True)
+        analysis = response.split("ANALYSIS:")[-1].strip()
 
         term_matches = self._find_key_terms(chunk)
 
@@ -262,7 +246,7 @@ ANALYSIS:"""
 
         print("\n" + "="*80 + "\n")
 
-def analyze_pdf_for_fraud(pdf_path, model, processor):
+def analyze_pdf_for_fraud(pdf_path, model, tokenizer):
     import pdfplumber
 
     with pdfplumber.open(pdf_path) as pdf:
@@ -270,7 +254,7 @@ def analyze_pdf_for_fraud(pdf_path, model, processor):
         for page in pdf.pages:
             text += page.extract_text() or ""
 
-        analyzer = HealthcareFraudAnalyzer(model, processor)
+        analyzer = HealthcareFraudAnalyzer(model, tokenizer)
         results = analyzer.analyze_document(text)
 
         analyzer.print_report(results)
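
For context, a minimal usage sketch of the interface after this change, assuming the model and tokenizer are loaded with Hugging Face transformers; the checkpoint name and PDF path below are placeholders, not specified by this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from document_analyzer import analyze_pdf_for_fraud

checkpoint = "your-org/your-llama-checkpoint"  # placeholder; this commit does not pin a checkpoint

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if tokenizer.pad_token is None:
    # analyze_chunk tokenizes with padding=True, and Llama tokenizers ship without a pad token
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16)

# After this commit the analyzer takes a tokenizer instead of a processor,
# so the PDF entry point is called as (pdf_path, model, tokenizer).
analyze_pdf_for_fraud("example_claims.pdf", model, tokenizer)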