Use device_map='auto' + offload_folder to avoid OOM
app.py
CHANGED
@@ -1,21 +1,25 @@
 # app.py
 
-import gradio as gr
-from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig
-import datasets
-import torch
 import os
+import json
+import re
+
+import gradio as gr
 import pdfplumber
 import nltk
 from nltk.tokenize import sent_tokenize
-from …
+from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig
+import datasets
+import torch
 from accelerate import Accelerator
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 import huggingface_hub
+
 from document_analyzer import HealthcareFraudAnalyzer
 
-print("Running updated app.py with CPU offloading (version: 2025-04-…
+print("Running updated app.py with CPU offloading (version: 2025-04-22 v1)")
 
-# — Ensure NLTK punkt is available
+# — Ensure NLTK punkt tokenizer is available
 try:
     nltk.data.find('tokenizers/punkt')
 except LookupError:
@@ -24,44 +28,46 @@ except LookupError:
 # — Authenticate with Hugging Face
 LLAMA = os.getenv("LLama")
 if not LLAMA:
-    raise ValueError("LLama token not found. …
+    raise ValueError("LLama token not found. Please set it as 'LLama' in your environment.")
 huggingface_hub.login(token=LLAMA)
 
-# — …
+# — Model and tokenizer setup
 MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 if tokenizer.pad_token is None:
     tokenizer.add_special_tokens({'pad_token': '[PAD]'})
 
-# — …
+# — BitsAndBytes quantization + CPU offload config
 quant_config = BitsAndBytesConfig(
     load_in_8bit=True,
     llm_int8_enable_fp32_cpu_offload=True
 )
 
-print("Loading model with …
+print("Loading model with 8-bit quantization, CPU offload, and automatic device mapping")
 model = Llama4ForConditionalGeneration.from_pretrained(
     MODEL_ID,
     torch_dtype=torch.bfloat16,
-    device_map="auto",
+    device_map="auto",          # let Accelerate decide which layers go to GPU vs. CPU
     quantization_config=quant_config,
-    offload_folder="./offload"
+    offload_folder="./offload"  # spill CPU-offloaded weights here
 )
 
-# — Resize embeddings if pad …
+# — Resize embeddings if we added a pad token
 model.resize_token_embeddings(len(tokenizer))
 
-# — …
+# — Prepare with Accelerate
 accelerator = Accelerator()
 model = accelerator.prepare(model)
 
-# — …
+# — Initialize the fraud analyzer
 analyzer = HealthcareFraudAnalyzer(model, tokenizer, accelerator)
 
 # — Fine-tune function
 def fine_tune_model(training_data_file, epochs=1, batch_size=2):
     try:
-        …
+        ds = datasets.load_dataset('json', data_files=training_data_file)['train']
+
+        # LoRA configuration
         lora_cfg = LoraConfig(
             r=16,
             lora_alpha=32,
@@ -70,9 +76,12 @@ def fine_tune_model(training_data_file, epochs=1, batch_size=2):
             bias="none",
             task_type="CAUSAL_LM"
         )
+
+        # Prepare for k-bit training
         local_model = prepare_model_for_kbit_training(model)
         local_model = get_peft_model(local_model, lora_cfg)
 
+        # Training arguments
        args = {
             "output_dir": "./results",
             "num_train_epochs": int(epochs),
@@ -87,48 +96,55 @@ def fine_tune_model(training_data_file, epochs=1, batch_size=2):
             "warmup_ratio": 0.03,
             "lr_scheduler_type": "cosine"
         }
+
         trainer = accelerator.prepare(
             datasets.Trainer(
                 model=local_model,
                 args=datasets.TrainingArguments(**args),
-                train_dataset=…
+                train_dataset=ds
             )
         )
+
         trainer.train()
         local_model.save_pretrained("./fine_tuned_model")
-        return f"Training completed …
+        return f"Training completed on {len(ds)} examples."
     except Exception as e:
         return f"Training failed: {e}"
 
 # — PDF analysis function
 def analyze_document(pdf_file):
     try:
-        …
-        …
+        text = ""
+        with pdfplumber.open(pdf_file.name) as pdf:
+            for page in pdf.pages:
+                text += page.extract_text() or ""
+
         sentences = sent_tokenize(text)
-        …
-        …
+        results = analyzer.analyze_document(sentences)
+
+        if not results:
             return "No fraud indicators detected."
-        …
-        …
+
+        report = "Potential Fraud Indicators Detected:\n\n"
+        for item in results:
             report += (
-                f"- {…
-                f"  Reason: {…
-                f"  Confidence: {…
+                f"- Sentence: {item['sentence']}\n"
+                f"  Reason: {item['reason']}\n"
+                f"  Confidence: {item['confidence']:.2f}\n\n"
             )
-        return report
+        return report.strip()
     except Exception as e:
         return f"Analysis failed: {e}"
 
-# — Gradio …
+# — Gradio Interface
 with gr.Blocks(theme=gr.themes.Default()) as demo:
     gr.Markdown("# Llama 4 Healthcare Fraud Detection")
 
-    with gr.Tab("Fine…
+    with gr.Tab("Fine-Tune Model"):
         training_data = gr.File(label="Upload Training JSON File")
         epochs = gr.Slider(1, 10, value=1, step=1, label="Epochs")
         batch_size = gr.Slider(1, 4, value=2, step=1, label="Batch Size")
-        train_button = gr.Button("Fine…
+        train_button = gr.Button("Fine-Tune")
         train_output = gr.Textbox(label="Training Output")
         train_button.click(
             fn=fine_tune_model,