Cylanoid committed
Commit b2c8265 · 1 parent: f5ef606

Use device_map='auto' + offload_folder to avoid OOM
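In isolation, the pattern this commit adopts is: quantize to 8-bit with BitsAndBytesConfig, let device_map="auto" split layers across GPU and CPU, and point offload_folder at a disk spill area for weights that fit on neither. Below is a minimal sketch of that loading path, assuming transformers, accelerate, and bitsandbytes are installed; the model ID here is a placeholder, not this Space's actual checkpoint.

```python
# Minimal sketch of the OOM-avoidance pattern this commit adopts.
# Assumption: "my-org/my-model" is a hypothetical causal-LM checkpoint.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_8bit=True,                      # 8-bit weights, roughly half the VRAM of fp16
    llm_int8_enable_fp32_cpu_offload=True   # layers that don't fit may run on CPU in fp32
)

model = AutoModelForCausalLM.from_pretrained(
    "my-org/my-model",                  # placeholder model ID
    torch_dtype=torch.bfloat16,
    device_map="auto",                  # Accelerate assigns each layer to GPU, CPU, or disk
    quantization_config=quant_config,
    offload_folder="./offload"          # disk spill area for weights that fit on neither device
)
```

Throughput drops whenever layers land on CPU or disk, but loading succeeds on hardware where a plain .to("cuda") would raise an out-of-memory error.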

Files changed (1):
  app.py (+48 -32)
app.py CHANGED
@@ -1,21 +1,25 @@
  # app.py

- import gradio as gr
- from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig, Trainer, TrainingArguments
- import datasets
- import torch
  import os
+ import json
+ import re
+
+ import gradio as gr
  import pdfplumber
  import nltk
  from nltk.tokenize import sent_tokenize
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig, Trainer, TrainingArguments
+ import datasets
+ import torch
  from accelerate import Accelerator
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
  import huggingface_hub
+
  from document_analyzer import HealthcareFraudAnalyzer

- print("Running updated app.py with CPU offloading (version: 2025-04-21 v3)")
+ print("Running updated app.py with CPU offloading (version: 2025-04-22 v1)")

- # — Ensure NLTK punkt is available
+ # — Ensure NLTK punkt tokenizer is available
  try:
      nltk.data.find('tokenizers/punkt')
  except LookupError:
@@ -24,44 +28,46 @@ except LookupError:
  # — Authenticate with Hugging Face
  LLAMA = os.getenv("LLama")
  if not LLAMA:
-     raise ValueError("LLama token not found. Set it in environment as 'LLama'.")
+     raise ValueError("LLama token not found. Please set it as 'LLama' in your environment.")
  huggingface_hub.login(token=LLAMA)

- # — Tokenizer
+ # — Model and tokenizer setup
  MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
  if tokenizer.pad_token is None:
      tokenizer.add_special_tokens({'pad_token': '[PAD]'})

- # — Quantization + CPU off-load config
+ # — BitsAndBytes quantization + CPU off-load config
  quant_config = BitsAndBytesConfig(
      load_in_8bit=True,
      llm_int8_enable_fp32_cpu_offload=True
  )

- print("Loading model with: quantization_config=", quant_config, ", device_map=auto")
+ print("Loading model with 8-bit quantization, CPU offload, and automatic device mapping")
  model = Llama4ForConditionalGeneration.from_pretrained(
      MODEL_ID,
      torch_dtype=torch.bfloat16,
-     device_map="auto",
+     device_map="auto",  # let Accelerate decide which layers go to GPU vs. CPU
      quantization_config=quant_config,
-     offload_folder="./offload"
+     offload_folder="./offload"  # spill CPU-offloaded weights here
  )

- # — Resize embeddings if pad was added
+ # — Resize embeddings if we added a pad token
  model.resize_token_embeddings(len(tokenizer))

- # — Accelerator prep
+ # — Prepare with Accelerate
  accelerator = Accelerator()
  model = accelerator.prepare(model)

- # — Analyzer instance
+ # — Initialize the fraud analyzer
  analyzer = HealthcareFraudAnalyzer(model, tokenizer, accelerator)

  # — Fine-tune function
  def fine_tune_model(training_data_file, epochs=1, batch_size=2):
      try:
-         dataset = datasets.load_dataset('json', data_files=training_data_file)['train']
+         ds = datasets.load_dataset('json', data_files=training_data_file)['train']
+
+         # LoRA configuration
          lora_cfg = LoraConfig(
              r=16,
              lora_alpha=32,
@@ -70,9 +76,12 @@ def fine_tune_model(training_data_file, epochs=1, batch_size=2):
              bias="none",
              task_type="CAUSAL_LM"
          )
+
+         # Prepare for k-bit training
          local_model = prepare_model_for_kbit_training(model)
          local_model = get_peft_model(local_model, lora_cfg)

+         # Training arguments
          args = {
              "output_dir": "./results",
              "num_train_epochs": int(epochs),
@@ -87,48 +96,55 @@ def fine_tune_model(training_data_file, epochs=1, batch_size=2):
              "warmup_ratio": 0.03,
              "lr_scheduler_type": "cosine"
          }
+
          trainer = accelerator.prepare(
              Trainer(
                  model=local_model,
                  args=TrainingArguments(**args),
-                 train_dataset=dataset
+                 train_dataset=ds
              )
          )
+
          trainer.train()
          local_model.save_pretrained("./fine_tuned_model")
-         return f"Training completed with {len(dataset)} examples!"
+         return f"Training completed on {len(ds)} examples."
      except Exception as e:
          return f"Training failed: {e}"

  # — PDF analysis function
  def analyze_document(pdf_file):
      try:
-         with pdfplumber.open(pdf_file) as pdf:
-             text = "".join(page.extract_text() or "" for page in pdf.pages)
+         text = ""
+         with pdfplumber.open(pdf_file.name) as pdf:
+             for page in pdf.pages:
+                 text += page.extract_text() or ""
+
          sentences = sent_tokenize(text)
-         fraud_indicators = analyzer.analyze_document(sentences)
-         if not fraud_indicators:
+         results = analyzer.analyze_document(sentences)
+
+         if not results:
              return "No fraud indicators detected."
-         report = "Potential Fraud Indicators Detected:\n"
-         for ind in fraud_indicators:
+
+         report = "Potential Fraud Indicators Detected:\n\n"
+         for item in results:
              report += (
-                 f"- {ind['sentence']}\n"
-                 f"  Reason: {ind['reason']}\n"
-                 f"  Confidence: {ind['confidence']:.2f}\n"
+                 f"- Sentence: {item['sentence']}\n"
+                 f"  Reason: {item['reason']}\n"
+                 f"  Confidence: {item['confidence']:.2f}\n\n"
              )
-         return report
+         return report.strip()
      except Exception as e:
          return f"Analysis failed: {e}"

- # — Gradio UI
+ # — Gradio Interface
  with gr.Blocks(theme=gr.themes.Default()) as demo:
      gr.Markdown("# Llama 4 Healthcare Fraud Detection")

-     with gr.Tab("FineTune Model"):
+     with gr.Tab("Fine-Tune Model"):
          training_data = gr.File(label="Upload Training JSON File")
          epochs = gr.Slider(1, 10, value=1, step=1, label="Epochs")
          batch_size = gr.Slider(1, 4, value=2, step=1, label="Batch Size")
-         train_button = gr.Button("FineTune")
+         train_button = gr.Button("Fine-Tune")
          train_output = gr.Textbox(label="Training Output")
          train_button.click(
              fn=fine_tune_model,