# updated_app.py
# Enhanced Gradio app for Llama 4 Maverick healthcare fraud detection
import gradio as gr
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Llama4ForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
import datasets
import torch
import json
import os
import pdfplumber
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator
import huggingface_hub
import re
import nltk
from nltk.tokenize import sent_tokenize
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
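# Note: some newer NLTK releases look up sentence-tokenizer data under
# 'tokenizers/punkt_tab'; if sent_tokenize fails at runtime, downloading
# 'punkt_tab' as well is a reasonable fallback.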
# Import the HealthcareFraudAnalyzer
from document_analyzer import HealthcareFraudAnalyzer
# Debug: confirm the 'LLama' secret is present without dumping all environment
# variables (printing os.environ wholesale would leak other secrets into the logs)
print("LLama token present:", "LLama" in os.environ)
# Retrieve the token from Hugging Face Space secrets
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")
# Debug: Print token (first 5 chars for security, remove in production)
print(f"Retrieved LLama token: {LLama[:5]}...")
# Authenticate with Hugging Face
huggingface_hub.login(token=LLama)
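# Optional sanity check (huggingface_hub.whoami raises if the token is invalid):
# print("Authenticated as:", huggingface_hub.whoami()["name"])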
# Model setup
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
# Load model with 8-bit (int8) quantization to fit in 80 GB VRAM
model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    attn_implementation="flex_attention"
)
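# Optional: report the quantized footprint (get_memory_footprint is a standard
# transformers PreTrainedModel method; useful to confirm the model actually fits)
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.1f} GB")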
# Prepare model for LoRA training
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
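# print_trainable_parameters emits a single line of the form (numbers illustrative):
# trainable params: N || all params: M || trainable%: P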
# Function to create training pairs from document text
def extract_training_pairs_from_text(text):
    pairs = []
    patterns = [
        # Medication patterns
        (
            r"(?i).*?\b(haloperidol|lorazepam|ativan)\b.*?\b(daily|routine|regular)\b.*?",
            "Patient receives {} on a {} basis. Is this appropriate medication management?",
            "This may indicate inappropriate medication management. Regular use of psychotropic medications without documented need assessment, behavior monitoring, and attempted dose reductions may violate care standards."
        ),
        # Documentation patterns
        (
            r"(?i).*?\b(missing|omitted|absent|lacking)\b.*?\b(documentation|records|logs|notes)\b.*?",
            "Facility has {} {} for patient care. Is this a documentation concern?",
            "Yes, incomplete documentation is a significant red flag. Missing records may indicate attempts to conceal care issues or fraudulent billing for services not provided."
        ),
        # Visitation patterns
        (
            r"(?i).*?\b(restrict|limit|prevent|block)\b.*?\b(visits|visitation|access|family)\b.*?",
            "Facility {} family {} without documented medical necessity. Is this suspicious?",
            "Yes, unjustified visitation restrictions may indicate attempts to conceal care issues and prevent family oversight. This can constitute fraud when facilities bill for care while violating resident rights."
        ),
        # Hospice patterns
        (
            r"(?i).*?\b(hospice|terminal|end.of.life)\b.*?\b(not|without|lacking)\b.*?\b(evidence|decline|documentation)\b.*?",
            "Patient placed on {} care {} supporting {}. Is this fraudulent?",
            "Yes, hospice enrollment without documented terminal decline may indicate Medicare fraud. Hospice certification requires genuine clinical determination of terminal status with prognosis of six months or less."
        ),
        # Contradictory documentation
        (
            r"(?i).*?\b(different|contradicts|conflicts|inconsistent)\b.*?\b(records|documentation|testimony|statements)\b.*?",
            "Records show {} {} about patient condition. Is this fraudulent documentation?",
            "Yes, contradictory documentation is a strong indicator of fraudulent record-keeping designed to misrepresent care quality or patient condition, particularly when official records differ from internal communications."
        )
    ]
    for pattern, input_template, output_text in patterns:
        matches = re.finditer(pattern, text)
        for match in matches:
            groups = match.groups()
            if len(groups) >= 2:
                input_text = input_template.format(*groups)
                pairs.append({
                    "input": input_text,
                    "output": output_text
                })
    # Fall back to generic pairs keyed on topic keywords when no pattern matched
    if not pairs:
        if any(x in text.lower() for x in ["medication", "prescribed", "administered"]):
            pairs.append({
                "input": "Medication records show inconsistencies in administration times. Is this concerning?",
                "output": "Yes, inconsistent medication administration timing may indicate fraudulent documentation or medication mismanagement that could harm patients."
            })
        if any(x in text.lower() for x in ["visit", "family", "spouse"]):
            pairs.append({
                "input": "Staff documents family visits inconsistently. Is this suspicious?",
                "output": "Yes, selective documentation of family visits indicates fraudulent record-keeping designed to create a false narrative about family involvement and patient responses."
            })
        if any(x in text.lower() for x in ["hospice", "terminal", "prognosis"]):
            pairs.append({
                "input": "Patient remained on hospice for extended period without documented decline. Is this Medicare fraud?",
                "output": "Yes, maintaining hospice services without documented decline suggests fraudulent hospice certification to obtain Medicare benefits inappropriately."
            })
    return pairs
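# Illustration of how the patterns above fire (hypothetical input):
#   extract_training_pairs_from_text(
#       "Nurse notes missing documentation for medication administration.")
# returns one pair whose input reads:
#   "Facility has missing documentation for patient care. Is this a documentation concern?"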
# Function to process uploaded files and train
def train_ui(files):
    try:
        raw_text = ""
        dataset = None
        for file in files:
            if file.name.endswith(".pdf"):
                with pdfplumber.open(file.name) as pdf:
                    for page in pdf.pages:
                        raw_text += page.extract_text() or ""
            elif file.name.endswith(".json"):
                with open(file.name, "r", encoding="utf-8") as f:
                    raw_data = json.load(f)
                training_data = raw_data.get("training_pairs", raw_data)
                with open("temp_fraud_data.json", "w", encoding="utf-8") as f:
                    json.dump({"training_pairs": training_data}, f)
                # 'field' selects the nested list inside the JSON wrapper
                dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json", field="training_pairs")
        if not raw_text and not dataset:
            return "Error: No valid PDF or JSON data found."
        if raw_text:
            training_data = extract_training_pairs_from_text(raw_text)
            with open("temp_fraud_data.json", "w") as f:
                json.dump({"training_pairs": training_data}, f)
            dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json", field="training_pairs")

        def tokenize_data(example):
            messages = [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": example['input']}]
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": example['output']}]
                }
            ]
            formatted_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            # Pass text by keyword: multimodal processors usually reserve the first
            # positional argument for images
            inputs = processor(text=formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
            inputs["labels"] = inputs["input_ids"].clone()
            # Mask padding positions so the loss ignores pad tokens
            inputs["labels"][inputs["attention_mask"] == 0] = -100
            return {k: v.squeeze(0) for k, v in inputs.items()}

        # tokenize_data processes one example at a time, so map must not be batched
        tokenized_dataset = dataset["train"].map(tokenize_data, remove_columns=dataset["train"].column_names)
        # Return tensors (not Python lists) so the collator below can torch.stack them
        tokenized_dataset.set_format("torch")
        training_args = TrainingArguments(
            output_dir="./fine_tuned_llama4_healthcare",
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,
            eval_strategy="no",
            save_strategy="epoch",
            save_total_limit=2,
            num_train_epochs=5,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            bf16=True,
            gradient_checkpointing=True,
            optim="adamw_torch",
            warmup_steps=100,
        )
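        # Effective batch size per optimizer step:
        # per_device_train_batch_size (2) x gradient_accumulation_steps (8) = 16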
        def custom_data_collator(features):
            return {
                "input_ids": torch.stack([f["input_ids"] for f in features]),
                "attention_mask": torch.stack([f["attention_mask"] for f in features]),
                "labels": torch.stack([f["labels"] for f in features]),
            }
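        # Note: transformers.default_data_collator would do essentially the same
        # stacking here; the explicit version keeps the expected keys obvious.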
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=custom_data_collator,
        )
        trainer.train()
        model.save_pretrained("./fine_tuned_llama4_healthcare")
        processor.save_pretrained("./fine_tuned_llama4_healthcare")
        return f"Training completed with {len(tokenized_dataset)} examples! Model saved to ./fine_tuned_llama4_healthcare"
    except Exception as e:
        return f"Error: {str(e)}. Please check file format, dependencies, or the LLama token."
# Function to analyze uploaded document for fraud
def analyze_document_ui(files):
    try:
        if not files:
            return "Error: No file uploaded. Please upload a PDF to analyze."
        # gr.File may deliver a single file object or a list, depending on file_count
        file = files[0] if isinstance(files, list) else files
        if not file.name.endswith(".pdf"):
            return "Error: Please upload a PDF file for analysis."
        raw_text = ""
        with pdfplumber.open(file.name) as pdf:
            for page in pdf.pages:
                raw_text += page.extract_text() or ""
        if not raw_text:
            return "Error: Could not extract text from the PDF. The file may be corrupt or contain only images."
        analyzer = HealthcareFraudAnalyzer(model, processor)
        results = analyzer.analyze_document(raw_text)
        return results["summary"]
    except Exception as e:
        return f"Error during document analysis: {str(e)}"
# Gradio UI with training and analysis tabs
with gr.Blocks(title="Healthcare Fraud Detection Suite") as demo:
    gr.Markdown("# Healthcare Fraud Detection Suite")
    with gr.Tabs():
        with gr.TabItem("Fine-Tune Model"):
            gr.Markdown("## Train Llama 4 for Healthcare Fraud Detection")
            gr.Markdown("Upload PDFs (e.g., care logs, medication records) or a JSON file with training pairs.")
            train_file_input = gr.File(label="Upload Files (PDF/JSON)", file_count="multiple")
            train_button = gr.Button("Start Fine-Tuning")
            train_output = gr.Textbox(label="Training Status", lines=5)
            train_button.click(fn=train_ui, inputs=train_file_input, outputs=train_output)
        with gr.TabItem("Analyze Document"):
            gr.Markdown("## Analyze Document for Healthcare Fraud Indicators")
            gr.Markdown("Upload a PDF document to analyze for potential fraud, neglect, or abuse indicators.")
            analyze_file_input = gr.File(label="Upload PDF Document")
            analyze_button = gr.Button("Analyze Document")
            analyze_output = gr.Markdown(label="Analysis Results")
            analyze_button.click(fn=analyze_document_ui, inputs=analyze_file_input, outputs=analyze_output)
    gr.Markdown("""
### About This Tool
This tool uses Llama 4 Maverick to identify patterns of potential fraud, neglect, and abuse in healthcare documentation.
The fine-tuning tab allows model customization with your examples or automatic extraction from documents.
The analysis tab scans documents for suspicious patterns, generating detailed reports.
**Note:** All analysis is performed locally - no data is shared externally.
""")
# Launch the Gradio app
demo.launch()