File size: 12,112 Bytes
2cf39df
50c6bec
2cf39df
1bf4b77
50c6bec
2cf39df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50c6bec
2cf39df
50c6bec
 
 
 
 
2cf39df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50c6bec
 
2cf39df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50c6bec
2cf39df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50c6bec
2cf39df
 
 
 
 
1bf4b77
2cf39df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bf4b77
2cf39df
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# updated_app.py
# Enhanced Gradio app for Llama 4 Maverick healthcare fraud detection (text-only)

import gradio as gr
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import datasets
import torch
import json
import os
import pdfplumber
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator
import huggingface_hub
import re
import nltk
from nltk.tokenize import sent_tokenize

# sent_tokenize needs the Punkt sentence model. Recent NLTK releases load it
# from the "punkt_tab" resource rather than the older pickled "punkt" one, so
# make sure both are available.
# NOTE(review): on NLTK versions that do not know "punkt_tab", nltk.download is
# expected to log an error and return False rather than raise — confirm.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource)

# Import the HealthcareFraudAnalyzer
from document_analyzer import HealthcareFraudAnalyzer

# Retrieve the Hugging Face access token from the Space secret named "LLama".
# Never print os.environ or any part of the token: the environment holds every
# secret configured for the Space, and application logs may be viewable by others.
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")

# Authenticate with Hugging Face Hub before any model/tokenizer download below.
huggingface_hub.login(token=LLama)

# Model setup
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Make sure a pad token exists. Reuse EOS instead of adding a brand-new "[PAD]"
# token: a newly added token may receive an id outside the model's embedding
# table, and the embeddings are never resized in this file.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load weights with 8-bit (bitsandbytes) quantization, bf16 for the remaining
# tensors, and device_map="auto" so layers are dispatched across available devices.
# NOTE(review): confirm that the 8-bit footprint of this 128-expert checkpoint
# actually fits the available VRAM; an earlier comment assumed 80 GB.
model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config={"load_in_8bit": True},
    attn_implementation="flex_attention"
)

# Prepare the quantized model for training and attach LoRA adapters to the
# attention projections; only the adapter weights are trainable afterwards.
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Function to create training pairs from document text
def extract_training_pairs_from_text(text):
    """Derive question/answer training pairs from free-form document text.

    Each regex rule pairs a case-insensitive pattern with a question template and
    a fixed answer. Every match of a rule produces one pair whose question is the
    template filled, in order, with the words the pattern captured. Matches never
    span lines because the patterns use ``.`` without DOTALL. Repeated phrases
    therefore produce repeated pairs.

    When no rule matches anywhere, keyword fallbacks are used instead: each
    fallback whose keywords appear in the lower-cased text adds one generic pair.

    Args:
        text: Raw document text, e.g. text extracted from PDF pages.

    Returns:
        A list of ``{"input": question, "output": answer}`` dicts (possibly empty).
    """
    regex_rules = (
        # Medication patterns
        (
            r"(?i).*?\b(haloperidol|lorazepam|ativan)\b.*?\b(daily|routine|regular)\b.*?",
            "Patient receives {} on a {} basis. Is this appropriate medication management?",
            "This may indicate inappropriate medication management. Regular use of psychotropic medications without documented need assessment, behavior monitoring, and attempted dose reductions may violate care standards.",
        ),
        # Documentation patterns
        (
            r"(?i).*?\b(missing|omitted|absent|lacking)\b.*?\b(documentation|records|logs|notes)\b.*?",
            "Facility has {} {} for patient care. Is this a documentation concern?",
            "Yes, incomplete documentation is a significant red flag. Missing records may indicate attempts to conceal care issues or fraudulent billing for services not provided.",
        ),
        # Visitation patterns
        (
            r"(?i).*?\b(restrict|limit|prevent|block)\b.*?\b(visits|visitation|access|family)\b.*?",
            "Facility {} family {} without documented medical necessity. Is this suspicious?",
            "Yes, unjustified visitation restrictions may indicate attempts to conceal care issues and prevent family oversight. This can constitute fraud when facilities bill for care while violating resident rights.",
        ),
        # Hospice patterns
        (
            r"(?i).*?\b(hospice|terminal|end.of.life)\b.*?\b(not|without|lacking)\b.*?\b(evidence|decline|documentation)\b.*?",
            "Patient placed on {} care {} supporting {}. Is this fraudulent?",
            "Yes, hospice enrollment without documented terminal decline may indicate Medicare fraud. Hospice certification requires genuine clinical determination of terminal status with prognosis of six months or less.",
        ),
        # Contradictory documentation
        (
            r"(?i).*?\b(different|contradicts|conflicts|inconsistent)\b.*?\b(records|documentation|testimony|statements)\b.*?",
            "Records show {} {} about patient condition. Is this fraudulent documentation?",
            "Yes, contradictory documentation is a strong indicator of fraudulent record-keeping designed to misrepresent care quality or patient condition, particularly when official records differ from internal communications.",
        ),
    )

    collected = []
    for regex, question_template, answer in regex_rules:
        for hit in re.finditer(regex, text):
            captured = hit.groups()
            if len(captured) < 2:
                continue
            collected.append({"input": question_template.format(*captured), "output": answer})

    if collected:
        return collected

    # Nothing matched: fall back to generic pairs keyed on simple keywords.
    lowered = text.lower()
    keyword_rules = (
        (
            ("medication", "prescribed", "administered"),
            "Medication records show inconsistencies in administration times. Is this concerning?",
            "Yes, inconsistent medication administration timing may indicate fraudulent documentation or medication mismanagement that could harm patients.",
        ),
        (
            ("visit", "family", "spouse"),
            "Staff documents family visits inconsistently. Is this suspicious?",
            "Yes, selective documentation of family visits indicates fraudulent record-keeping designed to create a false narrative about family involvement and patient responses.",
        ),
        (
            ("hospice", "terminal", "prognosis"),
            "Patient remained on hospice for extended period without documented decline. Is this Medicare fraud?",
            "Yes, maintaining hospice services without documented decline suggests fraudulent hospice certification to obtain Medicare benefits inappropriately.",
        ),
    )
    return [
        {"input": question, "output": answer}
        for keywords, question, answer in keyword_rules
        if any(word in lowered for word in keywords)
    ]

# Function to process uploaded files and train
def train_ui(files):
    """Fine-tune the module-level LoRA ``model`` on pairs built from uploaded files.

    PDFs are text-extracted and converted with ``extract_training_pairs_from_text``.
    Each JSON file must hold one of:
      * a list of ``{"input", "output"}`` objects;
      * a single such object;
      * an object whose ``"training_pairs"`` key holds either of the above.

    Pairs from every uploaded file are combined into one training set. Training
    updates ``model`` in place, then the model and tokenizer are saved to
    ``./fine_tuned_llama4_healthcare``.

    Args:
        files: Value of a multi-file ``gr.File`` component: a list of uploads
            (objects exposing ``.name`` or plain path strings), or ``None``.

    Returns:
        A human-readable status string. Failures are returned as strings that
        start with ``"Error:"`` instead of being raised, because the result is
        displayed directly in the Gradio UI.
    """
    try:
        pairs, found_source = _collect_training_pairs(files or [])
        if not found_source:
            return "Error: No valid PDF or JSON data found."
        if not pairs:
            return "Error: No training pairs could be extracted from the uploaded files."

        # The module-level transformers import does not bring these names in.
        from transformers import Trainer, TrainingArguments

        # Build the dataset in memory; a shared temp file could be overwritten by
        # concurrent requests.
        dataset = datasets.Dataset.from_list(
            [{"input": str(p["input"]), "output": str(p["output"])} for p in pairs]
        )
        tokenized_dataset = dataset.map(
            _tokenize_pairs, batched=True, remove_columns=dataset.column_names
        )

        training_args = TrainingArguments(
            output_dir="./fine_tuned_llama4_healthcare",
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,
            eval_strategy="no",
            save_strategy="epoch",
            save_total_limit=2,
            num_train_epochs=5,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            bf16=True,
            gradient_checkpointing=True,
            optim="adamw_torch",
            warmup_steps=100,
        )

        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=lambda features: _pad_collate(features, pad_id),
        )

        trainer.train()
        model.save_pretrained("./fine_tuned_llama4_healthcare")
        tokenizer.save_pretrained("./fine_tuned_llama4_healthcare")
        return f"Training completed with {len(tokenized_dataset)} examples! Model saved to ./fine_tuned_llama4_healthcare"

    except Exception as e:
        return f"Error: {str(e)}. Please check file format, dependencies, or the LLama token."


def _collect_training_pairs(files):
    """Gather training pairs from uploaded files; return ``(pairs, found_source)``.

    ``found_source`` is True when at least one JSON file was read or at least one
    PDF yielded non-blank text. Files with other extensions are ignored.
    """
    pairs = []
    pdf_texts = []
    saw_json = False
    for file in files:
        # Older Gradio passes tempfile wrappers (path in ``.name``); newer
        # versions pass path strings.
        path = getattr(file, "name", file)
        extension = os.path.splitext(path)[1].lower()
        if extension == ".pdf":
            pdf_texts.append(_read_pdf_text(path))
        elif extension == ".json":
            saw_json = True
            pairs.extend(_load_json_pairs(path))

    raw_text = "\n".join(text for text in pdf_texts if text.strip())
    if raw_text:
        pairs.extend(extract_training_pairs_from_text(raw_text))
    return pairs, saw_json or bool(raw_text)


def _read_pdf_text(path):
    """Return the extracted text of all pages of the PDF at *path*."""
    with pdfplumber.open(path) as pdf:
        # Newline between pages so words at page boundaries are not glued together.
        return "\n".join(page.extract_text() or "" for page in pdf.pages)


def _load_json_pairs(path):
    """Load and validate training pairs from the JSON file at *path*.

    Raises:
        ValueError: if the file does not contain pair objects with both an
            ``"input"`` and an ``"output"`` key.
    """
    with open(path, "r", encoding="utf-8") as fh:
        raw = json.load(fh)
    records = raw.get("training_pairs", raw) if isinstance(raw, dict) else raw
    if isinstance(records, dict):
        records = [records]
    name = os.path.basename(path)
    if not isinstance(records, list):
        raise ValueError(f"{name}: expected a list of training pairs")
    for index, record in enumerate(records):
        if not isinstance(record, dict) or "input" not in record or "output" not in record:
            raise ValueError(f"{name}: training pair {index} needs 'input' and 'output' keys")
    return records


def _tokenize_pairs(batch):
    """Tokenize a batch of pairs (for ``Dataset.map(batched=True)``) for causal-LM training.

    Sequences are truncated to 4096 tokens but not padded; ``_pad_collate`` pads
    each training batch to its own longest sequence. Labels copy the input ids.
    """
    texts = [
        f"<s>[INST] {question} [/INST] {answer}</s>"
        for question, answer in zip(batch["input"], batch["output"])
    ]
    encoded = tokenizer(texts, truncation=True, max_length=4096)
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "labels": [list(ids) for ids in encoded["input_ids"]],
    }


def _pad_collate(features, pad_id):
    """Right-pad a batch to its longest sequence and return long tensors.

    Padded label positions are set to -100 so the loss ignores them.
    """
    width = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        gap = width - len(f["input_ids"])
        batch["input_ids"].append(list(f["input_ids"]) + [pad_id] * gap)
        batch["attention_mask"].append(list(f["attention_mask"]) + [0] * gap)
        batch["labels"].append(list(f["labels"]) + [-100] * gap)
    return {key: torch.tensor(rows, dtype=torch.long) for key, rows in batch.items()}

# Function to analyze uploaded document for fraud
def analyze_document_ui(files):
    """Extract the text of an uploaded PDF and return the analyzer's summary.

    Args:
        files: Value of the analysis ``gr.File`` component. It may be a single
            upload (an object exposing ``.name`` or a path string) or a list of
            uploads, in which case only the first is analyzed. ``None`` means
            nothing was uploaded.

    Returns:
        ``results["summary"]`` from ``HealthcareFraudAnalyzer.analyze_document``,
        or an error string. Errors are returned rather than raised because the
        value is rendered directly in the UI.
    """
    try:
        # The analysis component is single-file, so accept a bare upload as well as a list.
        uploads = files if isinstance(files, (list, tuple)) else [files]
        upload = uploads[0] if uploads else None
        if upload is None:
            return "Error: No file uploaded. Please upload a PDF to analyze."

        # Older Gradio passes tempfile wrappers (path in ``.name``); newer versions pass path strings.
        path = getattr(upload, "name", upload)
        if not path.lower().endswith(".pdf"):
            return "Error: Please upload a PDF file for analysis."

        with pdfplumber.open(path) as pdf:
            # Newline between pages so words at page boundaries stay separate.
            raw_text = "\n".join(page.extract_text() or "" for page in pdf.pages)

        if not raw_text.strip():
            return "Error: Could not extract text from the PDF. The file may be corrupt or contain only images."

        analyzer = HealthcareFraudAnalyzer(model, tokenizer)
        results = analyzer.analyze_document(raw_text)
        return results["summary"]

    except Exception as e:
        return f"Error during document analysis: {str(e)}"

# Gradio UI with training and analysis tabs.
# Components render in creation order, so the layout below mirrors the page layout.
with gr.Blocks(title="Healthcare Fraud Detection Suite") as demo:
    gr.Markdown("# Healthcare Fraud Detection Suite")

    with gr.Tabs():
        with gr.TabItem("Fine-Tune Model"):
            gr.Markdown("## Train Llama 4 for Healthcare Fraud Detection")
            gr.Markdown("Upload PDFs (e.g., care logs, medication records) or a JSON file with training pairs.")
            # Limit the picker to the formats train_ui reads (.pdf and .json).
            train_file_input = gr.File(
                label="Upload Files (PDF/JSON)",
                file_count="multiple",
                file_types=[".pdf", ".json"],
            )
            train_button = gr.Button("Start Fine-Tuning")
            train_output = gr.Textbox(label="Training Status", lines=5)
            train_button.click(fn=train_ui, inputs=train_file_input, outputs=train_output)

        with gr.TabItem("Analyze Document"):
            gr.Markdown("## Analyze Document for Healthcare Fraud Indicators")
            gr.Markdown("Upload a PDF document to analyze for potential fraud, neglect, or abuse indicators.")
            # No file_count given, so Gradio's default (single file) applies; only
            # PDFs are accepted by analyze_document_ui.
            analyze_file_input = gr.File(label="Upload PDF Document", file_types=[".pdf"])
            analyze_button = gr.Button("Analyze Document")
            analyze_output = gr.Markdown(label="Analysis Results")
            analyze_button.click(fn=analyze_document_ui, inputs=analyze_file_input, outputs=analyze_output)

    gr.Markdown("""
    ### About This Tool
    This tool uses Llama 4 Maverick to identify patterns of potential fraud, neglect, and abuse in healthcare documentation.
    The fine-tuning tab allows model customization with your examples or automatic extraction from documents.
    The analysis tab scans documents for suspicious patterns, generating detailed reports.
    **Note:** All analysis is performed locally - no data is shared externally.
    """)

# Launch the Gradio app
demo.launch()