# document_analyzer.py
# Enhanced document analysis module for healthcare fraud detection with Llama 4 (text-only)
import torch
import re
from typing import List, Dict, Any
import nltk
from nltk.tokenize import sent_tokenize
# Make sure the NLTK Punkt sentence tokenizer data is available.
# Note: newer NLTK releases ship the model as 'punkt_tab', so fetch both if missing.
for resource in ('punkt', 'punkt_tab'):
    try:
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        nltk.download(resource)
class HealthcareFraudAnalyzer:
    """Chunk-level LLM analysis of healthcare documents for potential fraud indicators."""
    def __init__(self, model, tokenizer, device=None):
        self.model = model
        self.tokenizer = tokenizer
        # Llama-style tokenizers often ship without a pad token; fall back to EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
self.fraud_categories = [
"Consent violations",
"Documentation issues",
"Visitation restrictions",
"Medication misuse",
"Chemical restraint",
"Fraudulent billing",
"False testimony",
"Information concealment",
"Patient neglect",
"Hospice certification issues"
]
self.key_terms = {
"medication": ["haloperidol", "lorazepam", "sedation", "chemical", "restraint",
"prn", "as needed", "antipsychotic", "sedative", "benadryl",
"ativan", "seroquel", "comfort kit", "medication"],
"documentation": ["record", "documentation", "log", "chart", "note", "missing",
"altered", "backdated", "omit", "selective", "inconsistent"],
"visitation": ["visit", "restriction", "limit", "family", "spouse", "access",
"barrier", "monitor", "disruptive", "uncooperative"],
"consent": ["consent", "authorize", "approval", "permission", "against wishes",
"refused", "decline", "without knowledge"],
"hospice": ["hospice", "terminal", "end of life", "palliative", "comfort care",
"six months", "6 months", "prognosis", "certification"],
"billing": ["charge", "bill", "payment", "medicare", "medicaid", "insurance",
"reimbursement", "fee", "additional", "extra"]
}
    def chunk_document(self, text: str, chunk_size: int = 1024, overlap: int = 256) -> List[str]:
        """Split text into overlapping, sentence-aligned chunks of roughly chunk_size characters."""
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= chunk_size:
                current_chunk += sentence + " "
            else:
                if current_chunk.strip():
                    chunks.append(current_chunk.strip())
                # Carry the last `overlap` characters forward so context spans chunk boundaries.
                overlap_start = max(0, len(current_chunk) - overlap)
                current_chunk = current_chunk[overlap_start:] + sentence + " "
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        return chunks
    def analyze_chunk(self, chunk: str) -> Dict[str, Any]:
        """Prompt the model with a single chunk and return its analysis plus keyword matches."""
        prompt = f"""<s>[INST] Analyze the following healthcare document text for evidence of fraud, neglect, abuse, or criminal conduct.
Focus on: {', '.join(self.fraud_categories)}.
Provide specific indicators and cite the relevant text.
DOCUMENT TEXT:
{chunk}
ANALYSIS: [/INST]"""
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).to(self.device)
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
                temperature=0.1,
                top_p=0.9,
                repetition_penalty=1.2
            )
        # Decode only the newly generated tokens so the prompt text is not echoed into the analysis.
        generated = output[0][inputs["input_ids"].shape[1]:]
        analysis = self.tokenizer.decode(generated, skip_special_tokens=True).strip()
        term_matches = self._find_key_terms(chunk)
        return {
            "analysis": analysis,
            "term_matches": term_matches,
            "chunk_text": chunk[:200] + "..." if len(chunk) > 200 else chunk
        }
def _find_key_terms(self, text: str) -> Dict[str, List[str]]:
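        """Return, per term category, short text snippets surrounding each keyword hit."""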
text = text.lower()
results = {}
for category, terms in self.key_terms.items():
matches = []
for term in terms:
pattern = r'.{0,50}' + re.escape(term) + r'.{0,50}'
for match in re.finditer(pattern, text):
matches.append("..." + match.group(0) + "...")
if matches:
results[category] = matches
return results
def analyze_document(self, document_text: str) -> Dict[str, Any]:
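        """Normalize whitespace, chunk the document, analyze each chunk, and consolidate the results."""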
document_text = document_text.replace('\n', ' ').replace('\r', ' ')
document_text = re.sub(r'\s+', ' ', document_text)
chunks = self.chunk_document(document_text)
chunk_analyses = [self.analyze_chunk(chunk) for chunk in chunks]
consolidated_findings = self._consolidate_analyses(chunk_analyses)
return {
"summary": self._generate_summary(consolidated_findings, document_text),
"detailed_findings": consolidated_findings,
"chunk_analyses": chunk_analyses,
"document_metadata": {
"length": len(document_text),
"chunk_count": len(chunks)
}
}
def _consolidate_analyses(self, chunk_analyses: List[Dict[str, Any]]) -> Dict[str, Any]:
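        """Merge per-chunk keyword matches (deduplicated, capped at five) and group model findings by fraud category."""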
all_term_matches = {category: [] for category in self.key_terms.keys()}
for analysis in chunk_analyses:
for category, matches in analysis.get("term_matches", {}).items():
all_term_matches[category].extend(matches)
for category in all_term_matches:
if all_term_matches[category]:
deduplicated = []
for match in all_term_matches[category]:
if not any(match in other and match != other for other in all_term_matches[category]):
deduplicated.append(match)
all_term_matches[category] = deduplicated[:5]
categorized_findings = {category: [] for category in self.fraud_categories}
for analysis in chunk_analyses:
analysis_text = analysis.get("analysis", "")
for category in self.fraud_categories:
if category.lower() in analysis_text.lower():
sentences = sent_tokenize(analysis_text)
relevant = [s for s in sentences if category.lower() in s.lower()]
if relevant:
categorized_findings[category].extend(relevant)
return {
"term_matches": all_term_matches,
"categorized_findings": categorized_findings
}
def _generate_summary(self, findings: Dict[str, Any], full_text: str) -> str:
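        """Build a Markdown summary that ranks the top fraud categories and recommends next steps."""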
indicator_counts = {
category: len(findings["categorized_findings"].get(category, []))
for category in self.fraud_categories
}
        # Term-match categories (e.g. "billing") use different names from the fraud
        # categories (e.g. "Fraudulent billing"), so credit a term category to every
        # fraud category whose name contains it.
        term_match_counts = {
            category: sum(
                len(matches)
                for term_category, matches in findings["term_matches"].items()
                if term_category in category.lower()
            )
            for category in self.fraud_categories
        }
        sorted_categories = sorted(
            self.fraud_categories,
            key=lambda x: indicator_counts.get(x, 0) + term_match_counts.get(x, 0),
            reverse=True
        )
summary_lines = ["# Healthcare Fraud Detection Analysis", ""]
summary_lines.append("## Key Concerns Identified")
        for category in sorted_categories[:3]:
            if indicator_counts.get(category, 0) > 0 or term_match_counts.get(category, 0) > 0:
                summary_lines.append(f"### {category}")
                if findings["categorized_findings"].get(category):
                    summary_lines.append("Model analysis indicates:")
                    for finding in findings["categorized_findings"].get(category, [])[:3]:
                        summary_lines.append(f"- {finding}")
                # A term category (e.g. "billing") counts toward a fraud category whose name contains it.
                for term_category, matches in findings["term_matches"].items():
                    if term_category.lower() in category.lower() and matches:
                        summary_lines.append("Key terms identified:")
                        for match in matches[:3]:
                            summary_lines.append(f"- {match}")
                summary_lines.append("")
summary_lines.append("## Recommended Actions")
if sum(indicator_counts.values()) > 5:
summary_lines.append("- **Urgent review recommended** - Multiple indicators of potential fraud detected")
summary_lines.append("- Consider referral to appropriate regulatory authorities")
summary_lines.append("- Document preservation should be prioritized")
elif sum(indicator_counts.values()) > 2:
summary_lines.append("- **Further investigation recommended** - Several potential indicators identified")
summary_lines.append("- Conduct interviews with involved personnel")
summary_lines.append("- Secure additional documentation for verification")
else:
summary_lines.append("- **Monitor situation** - Limited indicators detected")
summary_lines.append("- Consider more specific document analysis")
return "\n".join(summary_lines)
def print_report(self, results: Dict[str, Any]) -> None:
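        """Print the summary, categorized findings, and key term matches to stdout."""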
print("\n" + "="*80)
print("HEALTHCARE FRAUD DETECTION REPORT")
print("="*80 + "\n")
print(results["summary"])
print("\n" + "="*80)
print("DETAILED FINDINGS")
print("="*80)
for category, findings in results["detailed_findings"]["categorized_findings"].items():
if findings:
print(f"\n## {category.upper()}")
for i, finding in enumerate(findings, 1):
print(f"{i}. {finding}")
print("\n" + "="*80)
print("KEY TERM MATCHES")
print("="*80)
for category, matches in results["detailed_findings"]["term_matches"].items():
if matches:
print(f"\n## {category.upper()}")
for match in matches:
print(f"- {match}")
print("\n" + "="*80 + "\n")
def analyze_pdf_for_fraud(pdf_path, model, tokenizer):
    """Extract text from a PDF, run the full fraud analysis, print a report, and return the results."""
    import pdfplumber
    text = ""
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            # Separate pages with a newline so sentences on adjacent pages do not run together.
            text += (page.extract_text() or "") + "\n"
    analyzer = HealthcareFraudAnalyzer(model, tokenizer)
    results = analyzer.analyze_document(text)
    analyzer.print_report(results)
    return results
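

# Minimal usage sketch, assuming a locally available fine-tuned causal LM and a PDF to analyze;
# the checkpoint name and PDF path below are placeholders, not part of the module itself.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "path/to/finetuned-fraud-model"  # placeholder checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

    # Runs PDF text extraction, chunked analysis, and prints the consolidated report.
    results = analyze_pdf_for_fraud("example_record.pdf", model, tokenizer)  # placeholder path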