Fix max_memory keys (use integer 0) or drop max_memory
a7aeb40
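Note: accelerate's max_memory mapping keys GPU caps by integer CUDA device index (0, 1, ...) while the CPU cap uses the string "cpu", so {0: "11GiB", "cpu": "200GiB"} is the valid form; quoting the GPU index as "0" is what this commit fixes.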
# app.py
import os
import gradio as gr
import pdfplumber
import nltk
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig, Trainer, TrainingArguments
import datasets
import torch
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import huggingface_hub
from document_analyzer import HealthcareFraudAnalyzer
print("Running updated app.py with restricted GPU usage (version: 2025-04-22 v2)")
# — Ensure NLTK punkt tokenizer is available
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
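# Note: NLTK 3.9+ loads 'punkt_tab' rather than 'punkt' for sent_tokenize; if a
# LookupError mentions punkt_tab, mirror this block with nltk.download('punkt_tab').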
# — Authenticate with Hugging Face
LLAMA = os.getenv("LLama")
if not LLAMA:
    raise ValueError("LLama token not found. Please set it as 'LLama' in your environment.")
huggingface_hub.login(token=LLAMA)
# — Model and tokenizer setup
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# — BitsAndBytes quantization + CPU off‑load config
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True
)
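# llm_int8_enable_fp32_cpu_offload lets layers the device map places on the CPU
# stay in fp32 instead of int8, which is what allows the 8-bit model to spill
# past the GPU cap set below.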
print("Loading model with 8-bit quantization, CPU offload, auto device mapping + max_memory cap")
model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_memory={         # cap GPU usage to ~11 GiB
        0: "11GiB",      # GPU cap keyed by integer CUDA index, per the commit title
        "cpu": "200GiB"
    },
    quantization_config=quant_config,
    offload_folder="./offload"
)
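# Per the commit title, the alternative is to drop max_memory entirely and let
# device_map="auto" size the GPU/CPU split on its own; a minimal sketch:
#   model = Llama4ForConditionalGeneration.from_pretrained(
#       MODEL_ID,
#       torch_dtype=torch.bfloat16,
#       device_map="auto",
#       quantization_config=quant_config,
#       offload_folder="./offload"
#   )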
# — Resize embeddings if we added a pad token
model.resize_token_embeddings(len(tokenizer))
# — Prepare with Accelerate
accelerator = Accelerator()
model = accelerator.prepare(model)
# — Initialize the fraud analyzer
analyzer = HealthcareFraudAnalyzer(model, tokenizer, accelerator)
# — Fine-tune function
def fine_tune_model(training_data_file, epochs=1, batch_size=2):
    try:
        # .name: access the uploaded file's path, matching the PDF handler below
        ds = datasets.load_dataset('json', data_files=training_data_file.name)['train']
        # Tokenize up front: Trainer expects input_ids/attention_mask/labels,
        # not raw JSON. Assumes each record carries a "text" field; adjust to
        # the actual schema if it differs.
        def tokenize_fn(batch):
            tokens = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512)
            tokens["labels"] = [ids.copy() for ids in tokens["input_ids"]]
            return tokens
        ds = ds.map(tokenize_fn, batched=True, remove_columns=ds.column_names)
        # LoRA configuration
        lora_cfg = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        # Prepare for k-bit training
        local_model = prepare_model_for_kbit_training(model)
        local_model = get_peft_model(local_model, lora_cfg)
        # Training arguments
        args = {
            "output_dir": "./results",
            "num_train_epochs": int(epochs),
            "per_device_train_batch_size": int(batch_size),
            "gradient_accumulation_steps": 8,
            "optim": "adamw_torch",
            "save_steps": 500,
            "logging_steps": 100,
            "learning_rate": 2e-4,
            "fp16": True,
            "max_grad_norm": 0.3,
            "warmup_ratio": 0.03,
            "lr_scheduler_type": "cosine"
        }
        # Trainer and TrainingArguments live in transformers, not datasets, and
        # Trainer drives its own Accelerator internally, so no accelerator.prepare() here.
        trainer = Trainer(
            model=local_model,
            args=TrainingArguments(**args),
            train_dataset=ds
        )
        trainer.train()
        local_model.save_pretrained("./fine_tuned_model")
        return f"Training completed on {len(ds)} examples."
    except Exception as e:
        return f"Training failed: {e}"
# — PDF analysis function
def analyze_document(pdf_file):
    try:
        text = ""
        with pdfplumber.open(pdf_file.name) as pdf:
            for page in pdf.pages:
                text += page.extract_text() or ""
        sentences = sent_tokenize(text)
        results = analyzer.analyze_document(sentences)
        if not results:
            return "No fraud indicators detected."
        report = "Potential Fraud Indicators Detected:\n\n"
        for item in results:
            report += (
                f"- Sentence: {item['sentence']}\n"
                f"  Reason: {item['reason']}\n"
                f"  Confidence: {item['confidence']:.2f}\n\n"
            )
        return report.strip()
    except Exception as e:
        return f"Analysis failed: {e}"
# — Gradio Interface
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Llama 4 Healthcare Fraud Detection")
    with gr.Tab("Fine-Tune Model"):
        training_data = gr.File(label="Upload Training JSON File")
        epochs = gr.Slider(1, 10, value=1, step=1, label="Epochs")
        batch_size = gr.Slider(1, 4, value=2, step=1, label="Batch Size")
        train_button = gr.Button("Fine-Tune")
        train_output = gr.Textbox(label="Training Output")
        train_button.click(
            fn=fine_tune_model,
            inputs=[training_data, epochs, batch_size],
            outputs=train_output
        )
    with gr.Tab("Analyze Document"):
        pdf_input = gr.File(label="Upload PDF Document")
        analyze_button = gr.Button("Analyze")
        analysis_output = gr.Textbox(label="Analysis Results")
        analyze_button.click(
            fn=analyze_document,
            inputs=pdf_input,
            outputs=analysis_output
        )
demo.launch(server_name="0.0.0.0", server_port=7860)