# app.py
"""Gradio app: Llama 4 healthcare-fraud analysis of PDFs plus LoRA fine-tuning.

At import time the module authenticates with Hugging Face, loads the model
(8-bit quantized, with CPU/disk offload), wraps it in a
``HealthcareFraudAnalyzer`` and serves two tabs: LoRA fine-tuning from an
uploaded JSON file, and fraud analysis of an uploaded PDF.
"""
import os

import gradio as gr
import pdfplumber
import nltk
from nltk.tokenize import sent_tokenize
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Llama4ForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
import datasets
import torch
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import huggingface_hub

from document_analyzer import HealthcareFraudAnalyzer

print("Running updated app.py with restricted GPU usage (version: 2025-04-22 v2)")

# — Ensure NLTK sentence-tokenizer data is available.
# Recent NLTK releases look up "punkt_tab" for sent_tokenize while older ones
# use "punkt"; make sure whichever one is missing gets downloaded.
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_resource}")
    except LookupError:
        nltk.download(_resource)

# — Authenticate with Hugging Face (token read from the "LLama" env var)
LLAMA = os.getenv("LLama")
if not LLAMA:
    raise ValueError("LLama token not found. Please set it as 'LLama' in your environment.")
huggingface_hub.login(token=LLAMA)

# — Model and tokenizer setup
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# — BitsAndBytes quantization + CPU off-load config
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

print("Loading model with 8-bit quantization, CPU offload, auto device mapping + max_memory cap")
# NOTE(review): whatever does not fit in the GPU/CPU budgets below is spilled
# to ./offload on disk; confirm the budget is realistic for this checkpoint.
model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_memory={  # cap GPU usage to ~11 GiB
        0: "11GiB",
        "cpu": "200GiB",
    },
    quantization_config=quant_config,
    offload_folder="./offload",
)

# — Resize embeddings only if the tokenizer outgrew them (e.g. after adding
# [PAD]). An unconditional resize to len(tokenizer) would *shrink* an
# embedding matrix that is padded larger than the tokenizer's vocabulary.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))

# — Prepare with Accelerate
# NOTE(review): the model is already dispatched by device_map="auto" with
# CPU/disk offload; confirm the installed accelerate version accepts an
# offloaded 8-bit model in Accelerator.prepare.
accelerator = Accelerator()
model = accelerator.prepare(model)

# — Initialize the fraud analyzer
analyzer = HealthcareFraudAnalyzer(model, tokenizer, accelerator)


def _upload_path(upload):
    """Return the filesystem path behind a Gradio file upload.

    ``gr.File`` may hand the callback either a path string or a
    tempfile-like object exposing ``.name``; accept both.
    """
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


# — Fine-tune function
def fine_tune_model(training_data_file, epochs=1, batch_size=2, text_field="text", max_length=512):
    """Train LoRA adapters on the shared model from an uploaded JSON file.

    Args:
        training_data_file: Gradio upload of a JSON / JSON-lines file readable
            by ``datasets.load_dataset("json")``. Every record must have a
            string column named ``text_field``.
        epochs: number of training epochs (cast with ``int`` because Gradio
            sliders may deliver floats).
        batch_size: per-device training batch size (cast with ``int``).
        text_field: name of the column that holds the training text.
        max_length: token length at which each example is truncated.

    Returns:
        A status string for the UI. Failures are reported in the string
        rather than raised, so the Gradio textbox always shows a result.
    """
    if training_data_file is None:
        return "Training failed: no training file was uploaded."
    try:
        ds = datasets.load_dataset('json', data_files=_upload_path(training_data_file))['train']
        if text_field not in ds.column_names:
            return (
                f"Training failed: column '{text_field}' not found in the training file "
                f"(available columns: {', '.join(ds.column_names)})."
            )

        def _tokenize(batch):
            return tokenizer(batch[text_field], truncation=True, max_length=int(max_length))

        # The Trainer consumes token ids, not raw JSON fields; drop the
        # original columns so only tokenizer outputs reach the model.
        tokenized = ds.map(_tokenize, batched=True, remove_columns=ds.column_names)

        # LoRA configuration
        lora_cfg = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        # Prepare for k-bit training.
        # NOTE(review): both calls below act on the shared module-level
        # `model` (which the analyzer also uses); the LoRA layers are
        # injected into its modules, so training affects later analyses.
        local_model = prepare_model_for_kbit_training(model)
        local_model = get_peft_model(local_model, lora_cfg)

        args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            gradient_accumulation_steps=8,
            optim="adamw_torch",
            save_steps=500,
            logging_steps=100,
            learning_rate=2e-4,
            # Match torch_dtype=torch.bfloat16 used at load time; fp16
            # autocast on a bf16 model risks activation overflow.
            bf16=True,
            max_grad_norm=0.3,
            warmup_ratio=0.03,
            lr_scheduler_type="cosine",
        )
        # Trainer drives Accelerate internally, so it is not wrapped in
        # accelerator.prepare (which would return it unchanged anyway).
        trainer = Trainer(
            model=local_model,
            args=args,
            train_dataset=tokenized,
            # mlm=False -> causal-LM labels are the input ids, padding masked.
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )
        trainer.train()
        # The analyzer shares these modules: switch dropout off again.
        local_model.eval()
        local_model.save_pretrained("./fine_tuned_model")
        return f"Training completed on {len(ds)} examples."
    except Exception as e:
        return f"Training failed: {e}"


# — PDF analysis function
def analyze_document(pdf_file):
    """Extract text from an uploaded PDF and report fraud indicators.

    The text is split into sentences with NLTK and passed to
    ``analyzer.analyze_document``. Each returned item is read via its
    'sentence', 'reason' and 'confidence' keys.

    Returns:
        A human-readable report string. Failures are reported in the string
        rather than raised, so the Gradio textbox always shows a result.
    """
    if pdf_file is None:
        return "Analysis failed: no PDF document was uploaded."
    try:
        with pdfplumber.open(_upload_path(pdf_file)) as pdf:
            # Separate pages so the last word of one page does not fuse with
            # the first word of the next; extract_text() may return None.
            text = "\n".join(page.extract_text() or "" for page in pdf.pages)
        if not text.strip():
            return "Analysis failed: no extractable text found in the PDF."

        sentences = sent_tokenize(text)
        results = analyzer.analyze_document(sentences)
        if not results:
            return "No fraud indicators detected."

        entries = (
            f"- Sentence: {item['sentence']}\n"
            f" Reason: {item['reason']}\n"
            f" Confidence: {item['confidence']:.2f}\n\n"
            for item in results
        )
        report = "Potential Fraud Indicators Detected:\n\n" + "".join(entries)
        return report.strip()
    except Exception as e:
        return f"Analysis failed: {e}"


# — Gradio Interface
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Llama 4 Healthcare Fraud Detection")

    with gr.Tab("Fine-Tune Model"):
        training_data = gr.File(label="Upload Training JSON File")
        epochs = gr.Slider(1, 10, value=1, step=1, label="Epochs")
        batch_size = gr.Slider(1, 4, value=2, step=1, label="Batch Size")
        train_button = gr.Button("Fine-Tune")
        train_output = gr.Textbox(label="Training Output")
        train_button.click(
            fn=fine_tune_model,
            inputs=[training_data, epochs, batch_size],
            outputs=train_output,
        )

    with gr.Tab("Analyze Document"):
        pdf_input = gr.File(label="Upload PDF Document")
        analyze_button = gr.Button("Analyze")
        analysis_output = gr.Textbox(label="Analysis Results")
        analyze_button.click(
            fn=analyze_document,
            inputs=pdf_input,
            outputs=analysis_output,
        )

demo.launch(server_name="0.0.0.0", server_port=7860)