Update app.py
app.py
CHANGED
@@ -1,8 +1,8 @@
 # updated_app.py
-# Enhanced Gradio app for Llama 4 Maverick healthcare fraud detection
+# Enhanced Gradio app for Llama 4 Maverick healthcare fraud detection (text-only)
 
 import gradio as gr
-from transformers import
+from transformers import AutoTokenizer, Llama4ForConditionalGeneration
 import datasets
 import torch
 import json
@@ -39,9 +39,13 @@ huggingface_hub.login(token=LLama)
 
 # Model setup
 MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
-
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
-#
+# Add padding token if it doesn't exist
+if tokenizer.pad_token is None:
+    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+# Load model with 8-bit quantization to fit in 80 GB VRAM
 model = Llama4ForConditionalGeneration.from_pretrained(
     MODEL_ID,
     torch_dtype=torch.bfloat16,
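Review note: if '[PAD]' is genuinely new to the vocabulary, adding it grows the tokenizer, and the model's embedding matrix has to be resized to match or the new pad id has no embedding row. A minimal sketch of that bookkeeping, assuming the `tokenizer` and `model` objects from this diff:

# add_special_tokens returns how many tokens were actually new
num_added = 0
if tokenizer.pad_token is None:
    num_added = tokenizer.add_special_tokens({'pad_token': '[PAD]'})
if num_added > 0:
    # grow the input embeddings so the new pad token id is valid
    model.resize_token_embeddings(len(tokenizer))

Also worth flagging: the new comment mentions 8-bit quantization, but the visible load call only sets torch.bfloat16; actually loading in 8-bit would need a quantization config (e.g. BitsAndBytesConfig(load_in_8bit=True)) passed to from_pretrained.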
@@ -157,18 +161,8 @@ def train_ui(files):
         dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
 
         def tokenize_data(example):
-            messages = [
-                {
-                    "role": "user",
-                    "content": [{"type": "text", "text": example['input']}]
-                },
-                {
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": example['output']}]
-                }
-            ]
-            formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)
-            inputs = processor(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
+            formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
+            inputs = tokenizer(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
             inputs["labels"] = inputs["input_ids"].clone()
             return {k: v.squeeze(0) for k, v in inputs.items()}
 
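Review note: cloning input_ids wholesale means every pad token up to max_length=4096 contributes to the loss. A hedged refinement sketch, reusing the names from this hunk and the standard Hugging Face ignore index of -100 (the hard-coded [INST] template is an assumption carried over from the diff as written; tokenizer.apply_chat_template would guarantee the model's own chat format instead):

def tokenize_data(example):
    formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
    inputs = tokenizer(formatted_text, padding="max_length", truncation=True,
                       max_length=4096, return_tensors="pt")
    labels = inputs["input_ids"].clone()
    labels[inputs["attention_mask"] == 0] = -100  # don't compute loss on padding
    inputs["labels"] = labels
    return {k: v.squeeze(0) for k, v in inputs.items()}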
@@ -208,7 +202,7 @@ def train_ui(files):
 
         trainer.train()
         model.save_pretrained("./fine_tuned_llama4_healthcare")
-
+        tokenizer.save_pretrained("./fine_tuned_llama4_healthcare")
         return f"Training completed with {len(tokenized_dataset)} examples! Model saved to ./fine_tuned_llama4_healthcare"
 
     except Exception as e:
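Review note: saving the tokenizer next to the weights is what makes the output directory self-contained. A quick round-trip check, assuming the directory name from the diff:

from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch

tok = AutoTokenizer.from_pretrained("./fine_tuned_llama4_healthcare")
ft_model = Llama4ForConditionalGeneration.from_pretrained(
    "./fine_tuned_llama4_healthcare", torch_dtype=torch.bfloat16
)
assert tok.pad_token is not None  # the '[PAD]' added above should round-trip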
@@ -232,7 +226,7 @@ def analyze_document_ui(files):
    if not raw_text:
        return "Error: Could not extract text from the PDF. The file may be corrupt or contain only images."
 
-    analyzer = HealthcareFraudAnalyzer(model, processor)
+    analyzer = HealthcareFraudAnalyzer(model, tokenizer)
    results = analyzer.analyze_document(raw_text)
    return results["summary"]
 
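Review note: HealthcareFraudAnalyzer's internals aren't shown in this diff, so the following is a hypothetical sketch of the kind of text-only generate call the class can now make with a tokenizer in place of the old processor (the prompt wording and the 2000-character truncation are illustrative only):

prompt = f"[INST] List potential fraud indicators in this record:\n{raw_text[:2000]} [/INST]"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=512)
summary = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)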