File size: 5,934 Bytes
6f2b1d7
9b2c756
2cf39df
1bf4b77
fab7ed8
2cf39df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36b5bed
fab7ed8
36b5bed
9b2c756
2cf39df
 
9b2c756
2cf39df
 
 
 
9b2c756
2cf39df
 
 
 
 
 
 
50c6bec
2cf39df
50c6bec
 
 
fab7ed8
 
 
 
 
 
 
6f2b1d7
 
fab7ed8
 
6f2b1d7
 
 
 
5997cdc
fab7ed8
5997cdc
36b5bed
fab7ed8
 
 
 
 
 
 
 
 
 
 
2cf39df
36b5bed
 
 
 
 
 
 
 
 
 
 
 
2cf39df
36b5bed
 
 
 
 
 
 
 
 
 
2cf39df
 
36b5bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cf39df
 
 
36b5bed
 
2cf39df
36b5bed
2cf39df
36b5bed
 
2cf39df
36b5bed
 
2cf39df
36b5bed
2cf39df
36b5bed
 
2cf39df
36b5bed
 
 
 
 
 
 
2cf39df
36b5bed
1bf4b77
36b5bed
 
 
2cf39df
36b5bed
 
 
 
 
 
 
 
 
 
 
2cf39df
36b5bed
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# app.py
# Gradio app for Llama 4 Maverick healthcare fraud detection (text-only with CPU offloading)

import gradio as gr
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig
import datasets
import torch
import json
import os
import pdfplumber
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator
import huggingface_hub
import re
import nltk
from nltk.tokenize import sent_tokenize

# Make sure the sentence-tokenizer data needed by sent_tokenize is installed.
# Older NLTK releases load the pickled "punkt" models, while recent releases
# look up the "punkt_tab" resource instead, so both are checked and whichever
# is missing is downloaded once at startup.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource)

# Import the HealthcareFraudAnalyzer
from document_analyzer import HealthcareFraudAnalyzer

# Debug: Confirm file version
print("Running updated app.py with CPU offloading (version: 2025-04-21 v3)")

# Debug: list only the *names* of the environment variables. The values are
# deliberately not printed: they include the Hugging Face token read below,
# and anything written to stdout can end up in shared logs.
print("Environment variable names:", sorted(os.environ))

# Retrieve the token from secrets
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")

# Debug: confirm the token was found without echoing any part of it
print("Retrieved LLama token.")

# Authenticate with Hugging Face
huggingface_hub.login(token=LLama)

# Model setup
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Add a pad token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Explicit quantization configuration: 8-bit weights, with fp32 CPU offload
# enabled so that layers mapped to "cpu" below are accepted.
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True
)

# Custom device map for CPU offloading (more layers to CPU).
# device_map keys are matched against module names, so every decoder layer
# needs its own entry; a range shorthand such as "model.layers.0-10" is not a
# module name and would leave those layers unmapped.
# NOTE(review): the "model." prefixes and the 32-layer total are carried over
# from the previous hand-written map; confirm both against
# model.named_modules() for this checkpoint.
GPU_LAYER_COUNT = 11    # layers 0-10 stay on GPU 0
TOTAL_LAYER_COUNT = 32  # layers 11-31 are offloaded to the CPU
device_map = {
    "model.embed_tokens": 0,
    **{f"model.layers.{index}": 0 for index in range(GPU_LAYER_COUNT)},
    **{
        f"model.layers.{index}": "cpu"
        for index in range(GPU_LAYER_COUNT, TOTAL_LAYER_COUNT)
    },
    "model.norm": 0,
    "lm_head": 0,
}

# Debug: Confirm offloading settings
print("Loading model with: quantization_config=", quant_config, ", device_map=", device_map)

# Load model with 8-bit quantization and CPU offloading
try:
    model = Llama4ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        quantization_config=quant_config,
        attn_implementation="flex_attention"
    )
except Exception as e:
    # Report the failure on the console, then re-raise: nothing below can run
    # without a loaded model.
    print(f"Model loading failed: {e}")
    raise

# Resize token embeddings only if the tokenizer has outgrown the embedding
# matrix (e.g. because a pad token was added). Resizing unconditionally could
# shrink or rebuild the embeddings of the quantized model for no benefit.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))

# Initialize Accelerator
# NOTE(review): the model is already placed by device_map above; confirm that
# accelerator.prepare() accepts an 8-bit model with CPU-offloaded layers.
accelerator = Accelerator()
model = accelerator.prepare(model)

# Initialize analyzer
analyzer = HealthcareFraudAnalyzer(model, tokenizer, accelerator)

# Training function
def fine_tune_model(training_data_file, epochs=1, batch_size=2):
    """Fine-tune the module-level model with LoRA adapters on a JSON dataset.

    Args:
        training_data_file: The uploaded training file, either a path string
            or an object that exposes the path as its ``name`` attribute.
        epochs: Number of training epochs (coerced with ``int``).
        batch_size: Per-device training batch size (coerced with ``int``).

    Returns:
        A status string: "Training completed with N examples!" on success, or
        "Training failed: ..." describing the problem. Exceptions are caught
        and reported in the returned string rather than raised.
    """
    if not training_data_file:
        return "Training failed: no training data file was provided."

    try:
        # Trainer and TrainingArguments are provided by transformers, not by
        # the datasets package.
        from transformers import Trainer, TrainingArguments

        # load_dataset needs a path; unwrap upload objects that carry it in
        # their ``name`` attribute.
        upload_name = getattr(training_data_file, "name", None)
        data_path = upload_name if isinstance(upload_name, str) else training_data_file

        dataset = datasets.load_dataset('json', data_files=data_path)['train']
        # NOTE(review): the dataset is handed to the Trainer as loaded; confirm
        # the JSON already contains tokenized model inputs, or add a
        # tokenization step here.

        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )

        # Use a separate local name: assigning to ``model`` inside this
        # function would make it a local variable and the read of the
        # module-level model would raise UnboundLocalError.
        peft_model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)

        # NOTE(review): the base model is loaded with torch_dtype=bfloat16
        # elsewhere in this file; confirm fp16 mixed precision is intended.
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            gradient_accumulation_steps=8,
            optim="adamw_torch",
            save_steps=500,
            logging_steps=100,
            learning_rate=2e-4,
            fp16=True,
            max_grad_norm=0.3,
            warmup_ratio=0.03,
            lr_scheduler_type="cosine",
        )

        trainer = Trainer(
            model=peft_model,
            args=training_args,
            train_dataset=dataset,
        )

        trainer.train()
        peft_model.save_pretrained("./fine_tuned_model")
        return f"Training completed with {len(dataset)} examples!"
    except Exception as e:
        # Report the failure as text instead of raising.
        return f"Training failed: {e}"

# Document analysis function
def analyze_document(pdf_file):
    """Extract the text of a PDF and report sentences flagged by the analyzer.

    Args:
        pdf_file: The uploaded PDF, either a path string, an object exposing
            the path as its ``name`` attribute, or a file-like object.

    Returns:
        A text report listing each flagged sentence with its reason and
        confidence, "No fraud indicators detected." when the analyzer returns
        nothing, or "Analysis failed: ..." when no file was given or an error
        occurred. Exceptions are caught and reported in the returned string.
    """
    if not pdf_file:
        return "Analysis failed: no PDF file was provided."

    try:
        # Prefer a filesystem path when the upload object carries one.
        upload_name = getattr(pdf_file, "name", None)
        pdf_source = upload_name if isinstance(upload_name, str) else pdf_file

        with pdfplumber.open(pdf_source) as pdf:
            page_texts = [page.extract_text() or "" for page in pdf.pages]

        # Separate pages with a newline so the last word of one page is not
        # fused with the first word of the next before sentence splitting.
        text = "\n".join(page_texts)

        sentences = sent_tokenize(text)
        fraud_indicators = analyzer.analyze_document(sentences)

        if not fraud_indicators:
            return "No fraud indicators detected."

        # Each indicator is read through its 'sentence', 'reason' and
        # 'confidence' keys; confidence is shown with two decimals.
        report_parts = ["Potential Fraud Indicators Detected:\n"]
        report_parts.extend(
            f"- {indicator['sentence']}\n  Reason: {indicator['reason']}\n  Confidence: {indicator['confidence']:.2f}\n"
            for indicator in fraud_indicators
        )
        return "".join(report_parts)
    except Exception as e:
        # Report the failure as text instead of raising.
        return f"Analysis failed: {e}"

# Gradio interface
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Llama 4 Healthcare Fraud Detection")

    # Tab 1: upload a JSON training file, pick the hyper-parameters and run
    # fine_tune_model; its status string is shown in the output textbox.
    with gr.Tab("Fine-Tune Model"):
        training_data = gr.File(label="Upload Training JSON File")
        epochs = gr.Slider(label="Epochs", minimum=1, maximum=10, step=1, value=1)
        batch_size = gr.Slider(label="Batch Size", minimum=1, maximum=4, step=1, value=2)
        train_button = gr.Button("Fine-Tune")
        train_output = gr.Textbox(label="Training Output")
        train_button.click(
            fine_tune_model,
            [training_data, epochs, batch_size],
            train_output,
        )

    # Tab 2: upload a PDF and run analyze_document; the resulting report is
    # shown in the output textbox.
    with gr.Tab("Analyze Document"):
        pdf_input = gr.File(label="Upload PDF Document")
        analyze_button = gr.Button("Analyze")
        analysis_output = gr.Textbox(label="Analysis Results")
        analyze_button.click(analyze_document, pdf_input, analysis_output)

# Serve on all network interfaces at port 7860.
demo.launch(server_port=7860, server_name="0.0.0.0")