# app.py
# Gradio app for Llama 4 Maverick healthcare fraud detection (text-only with CPU offloading)
import gradio as gr
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig, Trainer, TrainingArguments
import datasets
import torch
import json
import os
import pdfplumber
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator
import huggingface_hub
import re
import nltk
from nltk.tokenize import sent_tokenize
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
# Import the HealthcareFraudAnalyzer
from document_analyzer import HealthcareFraudAnalyzer
# Debug: Confirm file version
print("Running updated app.py with CPU offloading (version: 2025-04-21 v3)")
# Debug: Print environment variable names only (dumping values would leak secrets such as the HF token)
print("Environment variables:", sorted(os.environ.keys()))
# Retrieve the token from secrets
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")
# Debug: Print token (first 5 chars)
print(f"Retrieved LLama token: {LLama[:5]}...")
# Authenticate with Hugging Face
huggingface_hub.login(token=LLama)
# Model setup
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Explicit quantization configuration
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True
)
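# With load_in_8bit the GPU-resident weights are quantized to int8, while
# llm_int8_enable_fp32_cpu_offload lets the layers mapped to "cpu" below stay
# in fp32 rather than failing the int8 conversion during dispatch.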
# Custom device map for CPU offloading (more layers to CPU).
# Note: Accelerate device maps do not accept range keys like "model.layers.0-10";
# each layer must be mapped individually, so build the map programmatically.
device_map = {
    "model.embed_tokens": 0,
    "model.norm": 0,
    "lm_head": 0,
}
device_map.update({f"model.layers.{i}": 0 for i in range(11)})          # First 11 layers on GPU
device_map.update({f"model.layers.{i}": "cpu" for i in range(11, 32)})  # Remaining layers on CPU
# Debug: Confirm offloading settings
print("Loading model with: quantization_config=", quant_config, ", device_map=", device_map)
# Load model with 8-bit quantization and CPU offloading
try:
    model = Llama4ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        quantization_config=quant_config,
        attn_implementation="flex_attention"
    )
except Exception as e:
    print(f"Model loading failed: {str(e)}")
    raise
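# Optional sanity check: Accelerate records the resolved placement of each
# module in model.hf_device_map after a device_map load.
print("Resolved device map:", getattr(model, "hf_device_map", "n/a"))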
# Resize token embeddings if pad token was added
model.resize_token_embeddings(len(tokenizer))
# Initialize Accelerator
accelerator = Accelerator()
model = accelerator.prepare(model)
# Initialize analyzer
analyzer = HealthcareFraudAnalyzer(model, tokenizer, accelerator)
# Training function
def fine_tune_model(training_data_file, epochs=1, batch_size=2):
    try:
        # gr.File may hand back a path string or a tempfile wrapper, depending on the Gradio version
        data_path = getattr(training_data_file, "name", training_data_file)
        dataset = datasets.load_dataset('json', data_files=data_path)['train']
        # Tokenize the raw records so Trainer gets input_ids/labels
        # (assumes each JSON record has a "text" field; adjust to your schema)
        def tokenize_fn(examples):
            tokens = tokenizer(examples["text"], truncation=True, max_length=512, padding="max_length")
            tokens["labels"] = tokens["input_ids"].copy()
            return tokens
        dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        # Use a new local name: assigning to `model` inside this function would
        # shadow the global and raise UnboundLocalError on the right-hand side
        peft_model = prepare_model_for_kbit_training(model)
        peft_model = get_peft_model(peft_model, lora_config)
        # Trainer and TrainingArguments come from transformers, not datasets;
        # Trainer also drives Accelerate internally, so it must not be wrapped
        # in accelerator.prepare()
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            gradient_accumulation_steps=8,
            optim="adamw_torch",
            save_steps=500,
            logging_steps=100,
            learning_rate=2e-4,
            fp16=True,
            max_grad_norm=0.3,
            warmup_ratio=0.03,
            lr_scheduler_type="cosine"
        )
        trainer = Trainer(
            model=peft_model,
            args=training_args,
            train_dataset=dataset,
        )
        trainer.train()
        peft_model.save_pretrained("./fine_tuned_model")
        return f"Training completed with {len(dataset)} examples!"
    except Exception as e:
        return f"Training failed: {str(e)}"
# Document analysis function
def analyze_document(pdf_file):
    try:
        # gr.File may hand back a path string or a tempfile wrapper, depending on the Gradio version
        pdf_path = getattr(pdf_file, "name", pdf_file)
        with pdfplumber.open(pdf_path) as pdf:
            text = ""
            for page in pdf.pages:
                text += page.extract_text() or ""
        sentences = sent_tokenize(text)
        fraud_indicators = analyzer.analyze_document(sentences)
        if not fraud_indicators:
            return "No fraud indicators detected."
        report = "Potential Fraud Indicators Detected:\n"
        for indicator in fraud_indicators:
            report += (
                f"- {indicator['sentence']}\n"
                f"  Reason: {indicator['reason']}\n"
                f"  Confidence: {indicator['confidence']:.2f}\n"
            )
        return report
    except Exception as e:
        return f"Analysis failed: {str(e)}"
# Gradio interface
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Llama 4 Healthcare Fraud Detection")
    with gr.Tab("Fine-Tune Model"):
        training_data = gr.File(label="Upload Training JSON File")
        epochs = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Epochs")
        batch_size = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Batch Size")
        train_button = gr.Button("Fine-Tune")
        train_output = gr.Textbox(label="Training Output")
        train_button.click(
            fn=fine_tune_model,
            inputs=[training_data, epochs, batch_size],
            outputs=train_output
        )
    with gr.Tab("Analyze Document"):
        pdf_input = gr.File(label="Upload PDF Document")
        analyze_button = gr.Button("Analyze")
        analysis_output = gr.Textbox(label="Analysis Results")
        analyze_button.click(
            fn=analyze_document,
            inputs=pdf_input,
            outputs=analysis_output
        )

demo.launch(server_name="0.0.0.0", server_port=7860)
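# To run locally (a sketch; assumes the same "LLama" secret is set in the shell):
#   export LLama=<your HF token>
#   python app.py
# then open http://localhost:7860 in a browser.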