Cylanoid committed
Commit d9cfebf · verified · 1 Parent(s): 1bf4b77

Upload 4 files


Files to kick off this new space

Files changed (4)
  1. document_analyzer.py +277 -0
  2. requirements.txt +1 -0
  3. train_llama4.py +128 -0
  4. updated_app.py +272 -0
document_analyzer.py ADDED
@@ -0,0 +1,277 @@
# document_analyzer.py
# Enhanced document analysis module for healthcare fraud detection with Llama 4

import torch
import re
from typing import List, Dict, Any
import nltk
from nltk.tokenize import sent_tokenize

try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
# Newer NLTK releases load sentence models from punkt_tab instead.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')

class HealthcareFraudAnalyzer:
    def __init__(self, model, processor, device=None):
        self.model = model
        self.processor = processor
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # Models loaded with device_map="auto" (or quantized with bitsandbytes)
        # are already placed; calling .to() on them raises an error.
        if not hasattr(self.model, "hf_device_map"):
            self.model.to(self.device)
        self.model.eval()

        self.fraud_categories = [
            "Consent violations",
            "Documentation issues",
            "Visitation restrictions",
            "Medication misuse",
            "Chemical restraint",
            "Fraudulent billing",
            "False testimony",
            "Information concealment",
            "Patient neglect",
            "Hospice certification issues"
        ]

        self.key_terms = {
            "medication": ["haloperidol", "lorazepam", "sedation", "chemical", "restraint",
                           "prn", "as needed", "antipsychotic", "sedative", "benadryl",
                           "ativan", "seroquel", "comfort kit", "medication"],
            "documentation": ["record", "documentation", "log", "chart", "note", "missing",
                              "altered", "backdated", "omit", "selective", "inconsistent"],
            "visitation": ["visit", "restriction", "limit", "family", "spouse", "access",
                           "barrier", "monitor", "disruptive", "uncooperative"],
            "consent": ["consent", "authorize", "approval", "permission", "against wishes",
                        "refused", "decline", "without knowledge"],
            "hospice": ["hospice", "terminal", "end of life", "palliative", "comfort care",
                        "six months", "6 months", "prognosis", "certification"],
            "billing": ["charge", "bill", "payment", "medicare", "medicaid", "insurance",
                        "reimbursement", "fee", "additional", "extra"]
        }

    def chunk_document(self, text: str, chunk_size: int = 1024, overlap: int = 256) -> List[str]:
        """Split text into sentence-aligned chunks of roughly chunk_size characters,
        carrying the last `overlap` characters forward for cross-chunk context."""
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= chunk_size:
                current_chunk += sentence + " "
            else:
                chunks.append(current_chunk.strip())
                overlap_start = max(0, len(current_chunk) - overlap)
                current_chunk = current_chunk[overlap_start:] + sentence + " "

        if current_chunk.strip():
            chunks.append(current_chunk.strip())

        return chunks

    def analyze_chunk(self, chunk: str) -> Dict[str, Any]:
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"""Analyze the following healthcare document text for evidence of fraud, neglect, abuse, or criminal conduct.
Focus on: {', '.join(self.fraud_categories)}.
Provide specific indicators and cite the relevant text.

DOCUMENT TEXT:
{chunk}

ANALYSIS:"""
                    }
                ]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.1,
                top_p=0.9,
                repetition_penalty=1.2
            )

        # Decode only the newly generated tokens, dropping special tokens.
        response = self.processor.batch_decode(
            output[:, inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )[0]
        analysis = response.strip()

        term_matches = self._find_key_terms(chunk)

        return {
            "analysis": analysis,
            "term_matches": term_matches,
            "chunk_text": chunk[:200] + "..." if len(chunk) > 200 else chunk
        }

    def _find_key_terms(self, text: str) -> Dict[str, List[str]]:
        text = text.lower()
        results = {}

        for category, terms in self.key_terms.items():
            matches = []
            for term in terms:
                # Capture up to 50 characters of context on each side of the term.
                pattern = r'.{0,50}' + re.escape(term) + r'.{0,50}'
                for match in re.finditer(pattern, text):
                    matches.append("..." + match.group(0) + "...")

            if matches:
                results[category] = matches

        return results

    def analyze_document(self, document_text: str) -> Dict[str, Any]:
        document_text = document_text.replace('\n', ' ').replace('\r', ' ')
        document_text = re.sub(r'\s+', ' ', document_text)

        chunks = self.chunk_document(document_text)
        chunk_analyses = [self.analyze_chunk(chunk) for chunk in chunks]
        consolidated_findings = self._consolidate_analyses(chunk_analyses)

        return {
            "summary": self._generate_summary(consolidated_findings, document_text),
            "detailed_findings": consolidated_findings,
            "chunk_analyses": chunk_analyses,
            "document_metadata": {
                "length": len(document_text),
                "chunk_count": len(chunks)
            }
        }

    def _consolidate_analyses(self, chunk_analyses: List[Dict[str, Any]]) -> Dict[str, Any]:
        all_term_matches = {category: [] for category in self.key_terms.keys()}

        for analysis in chunk_analyses:
            for category, matches in analysis.get("term_matches", {}).items():
                all_term_matches[category].extend(matches)

        for category in all_term_matches:
            if all_term_matches[category]:
                # Drop matches that are substrings of longer matches, then cap at five.
                deduplicated = []
                for match in all_term_matches[category]:
                    if not any(match in other and match != other for other in all_term_matches[category]):
                        deduplicated.append(match)
                all_term_matches[category] = deduplicated[:5]

        categorized_findings = {category: [] for category in self.fraud_categories}

        for analysis in chunk_analyses:
            analysis_text = analysis.get("analysis", "")
            for category in self.fraud_categories:
                if category.lower() in analysis_text.lower():
                    sentences = sent_tokenize(analysis_text)
                    relevant = [s for s in sentences if category.lower() in s.lower()]
                    if relevant:
                        categorized_findings[category].extend(relevant)

        return {
            "term_matches": all_term_matches,
            "categorized_findings": categorized_findings
        }

    def _generate_summary(self, findings: Dict[str, Any], full_text: str) -> str:
        indicator_counts = {
            category: len(findings["categorized_findings"].get(category, []))
            for category in self.fraud_categories
        }

        term_match_counts = {
            category: len(matches)
            for category, matches in findings["term_matches"].items()
        }

        sorted_categories = sorted(
            self.fraud_categories,
            key=lambda x: indicator_counts.get(x, 0) + term_match_counts.get(x, 0),
            reverse=True
        )

        summary_lines = ["# Healthcare Fraud Detection Analysis", ""]
        summary_lines.append("## Key Concerns Identified")

        for category in sorted_categories[:3]:
            if indicator_counts.get(category, 0) > 0 or term_match_counts.get(category, 0) > 0:
                summary_lines.append(f"### {category}")

                if findings["categorized_findings"].get(category):
                    summary_lines.append("Model analysis indicates:")
                    for finding in findings["categorized_findings"].get(category, [])[:3]:
                        summary_lines.append(f"- {finding}")

                category_lower = category.lower().rstrip('s')
                for term_category, matches in findings["term_matches"].items():
                    if category_lower in term_category.lower() and matches:
                        summary_lines.append("Key terms identified:")
                        for match in matches[:3]:
                            summary_lines.append(f"- {match}")

                summary_lines.append("")

        summary_lines.append("## Recommended Actions")
        if sum(indicator_counts.values()) > 5:
            summary_lines.append("- **Urgent review recommended** - Multiple indicators of potential fraud detected")
            summary_lines.append("- Consider referral to appropriate regulatory authorities")
            summary_lines.append("- Document preservation should be prioritized")
        elif sum(indicator_counts.values()) > 2:
            summary_lines.append("- **Further investigation recommended** - Several potential indicators identified")
            summary_lines.append("- Conduct interviews with involved personnel")
            summary_lines.append("- Secure additional documentation for verification")
        else:
            summary_lines.append("- **Monitor situation** - Limited indicators detected")
            summary_lines.append("- Consider more specific document analysis")

        return "\n".join(summary_lines)

    def print_report(self, results: Dict[str, Any]) -> None:
        print("\n" + "=" * 80)
        print("HEALTHCARE FRAUD DETECTION REPORT")
        print("=" * 80 + "\n")

        print(results["summary"])

        print("\n" + "=" * 80)
        print("DETAILED FINDINGS")
        print("=" * 80)

        for category, findings in results["detailed_findings"]["categorized_findings"].items():
            if findings:
                print(f"\n## {category.upper()}")
                for i, finding in enumerate(findings, 1):
                    print(f"{i}. {finding}")

        print("\n" + "=" * 80)
        print("KEY TERM MATCHES")
        print("=" * 80)

        for category, matches in results["detailed_findings"]["term_matches"].items():
            if matches:
                print(f"\n## {category.upper()}")
                for match in matches:
                    print(f"- {match}")

        print("\n" + "=" * 80 + "\n")

def analyze_pdf_for_fraud(pdf_path, model, processor):
    import pdfplumber

    with pdfplumber.open(pdf_path) as pdf:
        text = ""
        for page in pdf.pages:
            text += page.extract_text() or ""

    analyzer = HealthcareFraudAnalyzer(model, processor)
    results = analyzer.analyze_document(text)

    analyzer.print_report(results)
    return results
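For reference, a minimal usage sketch for this module (not part of the commit). It mirrors the model-loading code in the other files; the PDF filename is illustrative, and access to the gated base model is assumed:

# usage_sketch.py - illustrative only, not included in this commit
import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration
from document_analyzer import analyze_pdf_for_fraud

MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
)

# Prints the formatted report to stdout and returns the structured results dict.
results = analyze_pdf_for_fraud("care_log.pdf", model, processor)  # hypothetical file
print(results["document_metadata"])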
requirements.txt ADDED
@@ -0,0 +1 @@
torch>=2.0.0
transformers>=4.51.0
datasets>=2.14.0
gradio>=4.0.0
pdfplumber>=0.10.0
peft>=0.14.0
bitsandbytes>=0.41.0
huggingface_hub>=0.19.0
accelerate>=0.21.0
nltk>=3.8.0
train_llama4.py ADDED
@@ -0,0 +1,128 @@
# train_llama4.py
# Script to fine-tune Llama 4 Maverick for healthcare fraud detection

from transformers import AutoProcessor, Llama4ForConditionalGeneration, Trainer, TrainingArguments
from transformers import BitsAndBytesConfig
import datasets
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator
import huggingface_hub
import os

# Version and CUDA check
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Is CUDA available: {torch.cuda.is_available()}")
# Guarded so the check also runs on CPU-only machines.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Authenticate with Hugging Face
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")
huggingface_hub.login(token=LLama)

# Load Llama 4 model and processor
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# 8-bit quantization (bitsandbytes) to reduce the VRAM footprint on an A100 80 GB
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flex_attention"
)

# Prepare for LoRA
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Load dataset
dataset = datasets.load_dataset("json", data_files="Bingaman_training_data.json", field="training_pairs")
print("First example from dataset:", dataset["train"][0])

# Tokenization
def tokenize_data(example):
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": example['input']}]
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example['output']}]
        }
    ]
    formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)
    inputs = processor(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
    input_ids = inputs["input_ids"].squeeze(0).tolist()
    attention_mask = inputs["attention_mask"].squeeze(0).tolist()
    # Mask padding positions with -100 so they are ignored by the loss.
    labels = [tok if mask == 1 else -100 for tok, mask in zip(input_ids, attention_mask)]
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask
    }

tokenized_dataset = dataset["train"].map(tokenize_data, batched=False, remove_columns=dataset["train"].column_names)
print("First tokenized example:", {k: (type(v), len(v)) for k, v in tokenized_dataset[0].items()})

# Data collator: datasets stores mapped values as lists, so convert back to tensors
def custom_data_collator(features):
    input_ids = [torch.tensor(f["input_ids"]) for f in features]
    attention_mask = [torch.tensor(f["attention_mask"]) for f in features]
    labels = [torch.tensor(f["labels"]) for f in features]
    return {
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attention_mask),
        "labels": torch.stack(labels)
    }

# Training setup
accelerator = Accelerator()
training_args = TrainingArguments(
    output_dir="./fine_tuned_llama4_healthcare",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=3,
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=5,
    bf16=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    warmup_steps=50
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset.select(range(min(5, len(tokenized_dataset)))),
    data_collator=custom_data_collator
)

# Start training
trainer.train()
model.save_pretrained("./fine_tuned_llama4_healthcare")
processor.save_pretrained("./fine_tuned_llama4_healthcare")
print("Training complete. Model and processor saved to ./fine_tuned_llama4_healthcare")
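The script loads `Bingaman_training_data.json` with `field="training_pairs"`, so the file must expose that key as a list of input/output pairs. A minimal sketch of the assumed shape, written as a Python literal (values illustrative):

# Assumed structure of Bingaman_training_data.json (contents illustrative):
example_data = {
    "training_pairs": [
        {
            "input": "Patient receives lorazepam on a daily basis. Is this appropriate medication management?",
            "output": "This may indicate inappropriate medication management. Regular use of psychotropic medications without documented need assessment may violate care standards."
        }
    ]
}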
updated_app.py ADDED
@@ -0,0 +1,272 @@
# updated_app.py
# Enhanced Gradio app for Llama 4 Maverick healthcare fraud detection

import gradio as gr
from transformers import AutoProcessor, Llama4ForConditionalGeneration, Trainer, TrainingArguments
from transformers import BitsAndBytesConfig
import datasets
import torch
import json
import os
import pdfplumber
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator
import huggingface_hub
import re
import nltk
from nltk.tokenize import sent_tokenize

try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
# Newer NLTK releases load sentence models from punkt_tab instead.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')

# Import the HealthcareFraudAnalyzer
from document_analyzer import HealthcareFraudAnalyzer

# Debug: Verify the 'LLama' secret is present (avoid dumping os.environ, which would leak secrets)
print("LLama secret present:", "LLama" in os.environ)

# Retrieve the token from Hugging Face Space secrets
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")

# Debug: Print token prefix only (remove in production)
print(f"Retrieved LLama token: {LLama[:5]}...")

# Authenticate with Hugging Face
huggingface_hub.login(token=LLama)

# Model setup
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load model with 8-bit quantization (bitsandbytes) to reduce the VRAM footprint
model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    attn_implementation="flex_attention"
)

# Prepare model for LoRA training
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Function to create training pairs from document text
def extract_training_pairs_from_text(text):
    pairs = []
    patterns = [
        # Medication patterns
        (
            r"(?i).*?\b(haloperidol|lorazepam|ativan)\b.*?\b(daily|routine|regular)\b.*?",
            "Patient receives {} on a {} basis. Is this appropriate medication management?",
            "This may indicate inappropriate medication management. Regular use of psychotropic medications without documented need assessment, behavior monitoring, and attempted dose reductions may violate care standards."
        ),
        # Documentation patterns
        (
            r"(?i).*?\b(missing|omitted|absent|lacking)\b.*?\b(documentation|records|logs|notes)\b.*?",
            "Facility has {} {} for patient care. Is this a documentation concern?",
            "Yes, incomplete documentation is a significant red flag. Missing records may indicate attempts to conceal care issues or fraudulent billing for services not provided."
        ),
        # Visitation patterns
        (
            r"(?i).*?\b(restrict|limit|prevent|block)\b.*?\b(visits|visitation|access|family)\b.*?",
            "Facility {} family {} without documented medical necessity. Is this suspicious?",
            "Yes, unjustified visitation restrictions may indicate attempts to conceal care issues and prevent family oversight. This can constitute fraud when facilities bill for care while violating resident rights."
        ),
        # Hospice patterns
        (
            r"(?i).*?\b(hospice|terminal|end.of.life)\b.*?\b(not|without|lacking)\b.*?\b(evidence|decline|documentation)\b.*?",
            "Patient placed on {} care {} supporting {}. Is this fraudulent?",
            "Yes, hospice enrollment without documented terminal decline may indicate Medicare fraud. Hospice certification requires genuine clinical determination of terminal status with prognosis of six months or less."
        ),
        # Contradictory documentation
        (
            r"(?i).*?\b(different|contradicts|conflicts|inconsistent)\b.*?\b(records|documentation|testimony|statements)\b.*?",
            "Records show {} {} about patient condition. Is this fraudulent documentation?",
            "Yes, contradictory documentation is a strong indicator of fraudulent record-keeping designed to misrepresent care quality or patient condition, particularly when official records differ from internal communications."
        )
    ]

    for pattern, input_template, output_text in patterns:
        matches = re.finditer(pattern, text)
        for match in matches:
            groups = match.groups()
            if len(groups) >= 2:
                input_text = input_template.format(*groups)
                pairs.append({
                    "input": input_text,
                    "output": output_text
                })

    # Fall back to generic pairs when no pattern matched
    if not pairs:
        if any(x in text.lower() for x in ["medication", "prescribed", "administered"]):
            pairs.append({
                "input": "Medication records show inconsistencies in administration times. Is this concerning?",
                "output": "Yes, inconsistent medication administration timing may indicate fraudulent documentation or medication mismanagement that could harm patients."
            })
        if any(x in text.lower() for x in ["visit", "family", "spouse"]):
            pairs.append({
                "input": "Staff documents family visits inconsistently. Is this suspicious?",
                "output": "Yes, selective documentation of family visits indicates fraudulent record-keeping designed to create a false narrative about family involvement and patient responses."
            })
        if any(x in text.lower() for x in ["hospice", "terminal", "prognosis"]):
            pairs.append({
                "input": "Patient remained on hospice for extended period without documented decline. Is this Medicare fraud?",
                "output": "Yes, maintaining hospice services without documented decline suggests fraudulent hospice certification to obtain Medicare benefits inappropriately."
            })

    return pairs

# Function to process uploaded files and train
def train_ui(files):
    try:
        raw_text = ""
        dataset = None
        for file in files:
            if file.name.endswith(".pdf"):
                with pdfplumber.open(file.name) as pdf:
                    for page in pdf.pages:
                        raw_text += page.extract_text() or ""
            elif file.name.endswith(".json"):
                with open(file.name, "r", encoding="utf-8") as f:
                    raw_data = json.load(f)
                training_data = raw_data.get("training_pairs", raw_data)
                with open("temp_fraud_data.json", "w", encoding="utf-8") as f:
                    json.dump({"training_pairs": training_data}, f)
                dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json", field="training_pairs")

        if not raw_text and not dataset:
            return "Error: No valid PDF or JSON data found."

        if raw_text:
            training_data = extract_training_pairs_from_text(raw_text)
            with open("temp_fraud_data.json", "w") as f:
                json.dump({"training_pairs": training_data}, f)
            dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json", field="training_pairs")

        def tokenize_data(example):
            messages = [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": example['input']}]
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": example['output']}]
                }
            ]
            formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)
            inputs = processor(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
            # Labels mirror input_ids; mask padding with -100 so it is ignored by the loss.
            labels = inputs["input_ids"].clone()
            labels[inputs["attention_mask"] == 0] = -100
            inputs["labels"] = labels
            return {k: v.squeeze(0) for k, v in inputs.items()}

        # tokenize_data handles one example at a time, so batched must be False
        tokenized_dataset = dataset["train"].map(tokenize_data, batched=False, remove_columns=dataset["train"].column_names)

        training_args = TrainingArguments(
            output_dir="./fine_tuned_llama4_healthcare",
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,
            eval_strategy="no",
            save_strategy="epoch",
            save_total_limit=2,
            num_train_epochs=5,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            bf16=True,
            gradient_checkpointing=True,
            optim="adamw_torch",
            warmup_steps=100,
        )

        # datasets stores mapped tensors as lists, so convert back before stacking
        def custom_data_collator(features):
            return {
                "input_ids": torch.stack([torch.tensor(f["input_ids"]) for f in features]),
                "attention_mask": torch.stack([torch.tensor(f["attention_mask"]) for f in features]),
                "labels": torch.stack([torch.tensor(f["labels"]) for f in features]),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=custom_data_collator,
        )

        trainer.train()
        model.save_pretrained("./fine_tuned_llama4_healthcare")
        processor.save_pretrained("./fine_tuned_llama4_healthcare")
        return f"Training completed with {len(tokenized_dataset)} examples! Model saved to ./fine_tuned_llama4_healthcare"

    except Exception as e:
        return f"Error: {str(e)}. Please check file format, dependencies, or the LLama token."

# Function to analyze an uploaded document for fraud
def analyze_document_ui(files):
    try:
        if not files:
            return "Error: No file uploaded. Please upload a PDF to analyze."

        # gr.File with the default file_count returns a single file, not a list
        file = files[0] if isinstance(files, list) else files
        if not file.name.endswith(".pdf"):
            return "Error: Please upload a PDF file for analysis."

        raw_text = ""
        with pdfplumber.open(file.name) as pdf:
            for page in pdf.pages:
                raw_text += page.extract_text() or ""

        if not raw_text:
            return "Error: Could not extract text from the PDF. The file may be corrupt or contain only images."

        analyzer = HealthcareFraudAnalyzer(model, processor)
        results = analyzer.analyze_document(raw_text)
        return results["summary"]

    except Exception as e:
        return f"Error during document analysis: {str(e)}"

# Gradio UI with training and analysis tabs
with gr.Blocks(title="Healthcare Fraud Detection Suite") as demo:
    gr.Markdown("# Healthcare Fraud Detection Suite")

    with gr.Tabs():
        with gr.TabItem("Fine-Tune Model"):
            gr.Markdown("## Train Llama 4 for Healthcare Fraud Detection")
            gr.Markdown("Upload PDFs (e.g., care logs, medication records) or a JSON file with training pairs.")
            train_file_input = gr.File(label="Upload Files (PDF/JSON)", file_count="multiple")
            train_button = gr.Button("Start Fine-Tuning")
            train_output = gr.Textbox(label="Training Status", lines=5)
            train_button.click(fn=train_ui, inputs=train_file_input, outputs=train_output)

        with gr.TabItem("Analyze Document"):
            gr.Markdown("## Analyze Document for Healthcare Fraud Indicators")
            gr.Markdown("Upload a PDF document to analyze for potential fraud, neglect, or abuse indicators.")
            analyze_file_input = gr.File(label="Upload PDF Document")
            analyze_button = gr.Button("Analyze Document")
            analyze_output = gr.Markdown(label="Analysis Results")
            analyze_button.click(fn=analyze_document_ui, inputs=analyze_file_input, outputs=analyze_output)

    gr.Markdown("""
    ### About This Tool
    This tool uses Llama 4 Maverick to identify patterns of potential fraud, neglect, and abuse in healthcare documentation.
    The fine-tuning tab lets you customize the model with your own examples or with pairs extracted automatically from documents.
    The analysis tab scans documents for suspicious patterns and generates detailed reports.
    **Note:** All analysis is performed locally - no data is shared externally.
    """)

# Launch the Gradio app
demo.launch()