Cylanoid committed (verified)
Commit bc32b76 · 1 Parent(s): cffe234

Update app.py

Files changed (1):
  app.py +102 -110
app.py CHANGED
@@ -1,134 +1,133 @@
- # app.py (corrected version)
-
- # Handle missing dependencies first
- try:
-     import gradio as gr
-     from transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments
-     import datasets
-     import torch
-     import json
-     import os
-     from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-     from accelerate import Accelerator
-     import bitsandbytes
-     import sentencepiece  # Added for Llama tokenizer
- except ImportError as e:
-     missing_package = str(e).split("'")[-2]  # Extract the missing package name
-     if "accelerate" in missing_package:
-         os.system(f'pip install "accelerate>=0.26.0"')
-     elif "sentencepiece" in missing_package:
-         os.system(f'pip install "sentencepiece"')
-     else:
-         os.system(f'pip install "{missing_package}"')
-     # Re-import after installation
-     import gradio as gr
-     from transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments
-     import datasets
-     import torch
-     import json
-     import os
-     from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-     from accelerate import Accelerator
-     import bitsandbytes
-     import sentencepiece

  # Model setup
- MODEL_ID = "meta-llama/Llama-2-7b-hf"  # Use Llama-2-7b; switch to "meta-llama/Llama-3-8b-hf" for Llama 3
- tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID)

- # Add padding token if it doesn't exist (required for Llama models)
  if tokenizer.pad_token is None:
      tokenizer.add_special_tokens({'pad_token': '[PAD]'})

- # Check if CUDA is available to enable Flash Attention 2
- use_flash_attention = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8  # Ampere or newer (e.g., A100)
-
- # Load the model with optimizations for Llama
  model = LlamaForCausalLM.from_pretrained(
      MODEL_ID,
-     torch_dtype=torch.bfloat16,  # Better for A100 GPUs, falls back to float16 on CPU
      device_map="auto",
-     use_flash_attention_2=use_flash_attention,  # Only enable if GPU supports it
-     load_in_8bit=True  # Quantization for memory efficiency
  )

- # Prepare the model for training with LoRA (more memory-efficient)
  model = prepare_model_for_kbit_training(model)
-
- # LoRA configuration
  peft_config = LoraConfig(
-     r=16,  # Rank
-     lora_alpha=32,  # Alpha
-     lora_dropout=0.05,  # Dropout
      bias="none",
      task_type="CAUSAL_LM",
-     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]  # Attention modules for Llama
  )
-
  model = get_peft_model(model, peft_config)
- model.print_trainable_parameters()  # Print percentage of trainable parameters

- # Function to process uploaded JSON and train
- def train_ui_tars(file):
      try:
-         # Step 1: Load and preprocess the uploaded JSON file
-         with open(file.name, "r", encoding="utf-8") as f:
-             raw_data = json.load(f)
-
-         # Extract training pairs or use flat structure
-         training_data = raw_data.get("training_pairs", raw_data)
-
-         # Save fixed JSON to avoid issues
-         fixed_json_path = "fixed_fraud_data.json"
-         with open(fixed_json_path, "w", encoding="utf-8") as f:
-             json.dump(training_data, f, indent=4)
-
-         # Load dataset
-         dataset = datasets.load_dataset("json", data_files=fixed_json_path)
-
-         # Step 2: Tokenize dataset with Llama-compatible context length
          def tokenize_data(example):
-             # Format input for Llama (instruction-following style)
              formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
-             inputs = tokenizer(
-                 formatted_text,
-                 padding="max_length",
-                 truncation=True,
-                 max_length=2048,  # Llama 2 context length; adjust to 8192 for Llama 3 if needed
-                 return_tensors="pt"
-             )
              inputs["labels"] = inputs["input_ids"].clone()
              return {k: v.squeeze(0) for k, v in inputs.items()}
-
          tokenized_dataset = dataset["train"].map(tokenize_data, batched=True, remove_columns=dataset["train"].column_names)
-
-         # Step 3: Training setup
          training_args = TrainingArguments(
-             output_dir="./fine_tuned_llama",
-             per_device_train_batch_size=4,  # Increased for better efficiency
-             gradient_accumulation_steps=8,  # To handle larger effective batch size
-             evaluation_strategy="no",
              save_strategy="epoch",
              save_total_limit=2,
-             num_train_epochs=3,
              learning_rate=2e-5,
              weight_decay=0.01,
              logging_dir="./logs",
              logging_steps=10,
-             bf16=True,  # Use bfloat16 for A100 GPUs, falls back to float16 on CPU
-             gradient_checkpointing=True,  # Memory optimization
              optim="adamw_torch",
              warmup_steps=100,
          )
-
-         # Custom data collator for Llama
          def custom_data_collator(features):
-             batch = {
                  "input_ids": torch.stack([f["input_ids"] for f in features]),
                  "attention_mask": torch.stack([f["attention_mask"] for f in features]),
                  "labels": torch.stack([f["labels"] for f in features]),
              }
-             return batch

          trainer = Trainer(
              model=model,
@@ -136,28 +135,21 @@ def train_ui_tars(file):
              train_dataset=tokenized_dataset,
              data_collator=custom_data_collator,
          )
-
-         # Step 4: Start training
          trainer.train()
-
-         # Step 5: Save the model
-         model.save_pretrained("./fine_tuned_llama")
-         tokenizer.save_pretrained("./fine_tuned_llama")
-
-         return "Training completed successfully! Model saved to ./fine_tuned_llama"
-
      except Exception as e:
-         return f"Error: {str(e)}"

  # Gradio UI
- with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
-     gr.Markdown("# Llama Fraud Detection Fine-Tuning UI")
-     gr.Markdown("Upload a JSON file with 'input' and 'output' pairs to fine-tune the Llama model on your fraud dataset.")
-
-     file_input = gr.File(label="Upload Fraud Dataset (JSON)")
      train_button = gr.Button("Start Fine-Tuning")
-     output = gr.Textbox(label="Training Status")
-
-     train_button.click(fn=train_ui_tars, inputs=file_input, outputs=output)

  demo.launch()

+ # app.py
+
+ import gradio as gr
+ from transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments  # Trainer/TrainingArguments are still used below and must stay imported
+ import datasets
+ import torch
+ import json
+ import os
+ import pdfplumber
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ from accelerate import Accelerator
+ import bitsandbytes
+ import sentencepiece
+ import huggingface_hub
+
+ # Retrieve HF_TOKEN from Hugging Face Space secrets
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ if not HF_TOKEN:
+     raise ValueError("HF_TOKEN not found in environment variables. Please set it in Hugging Face Space secrets under 'Settings' > 'Secrets'.")
+
+ # Authenticate with Hugging Face
+ huggingface_hub.login(token=HF_TOKEN)

  # Model setup
+ MODEL_ID = "meta-llama/Llama-2-7b-hf"
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

+ # Add padding token if it doesn't exist
  if tokenizer.pad_token is None:
      tokenizer.add_special_tokens({'pad_token': '[PAD]'})

+ # Check CUDA and enable Flash Attention if supported
+ use_flash_attention = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

  model = LlamaForCausalLM.from_pretrained(
      MODEL_ID,
+     torch_dtype=torch.bfloat16,
      device_map="auto",
+     use_flash_attention_2=use_flash_attention,
+     load_in_8bit=True
  )
+ # Resize embeddings only after the model is loaded (the added [PAD] token grew
+ # the vocabulary); calling this before from_pretrained() raises a NameError
+ model.resize_token_embeddings(len(tokenizer))

+ # Prepare model for LoRA training
  model = prepare_model_for_kbit_training(model)
  peft_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
      bias="none",
      task_type="CAUSAL_LM",
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
  )
  model = get_peft_model(model, peft_config)
+ model.print_trainable_parameters()

+ # Function to process uploaded files and train
+ def train_ui(files):
      try:
+         # Process multiple PDFs or JSON
+         raw_text = ""
+         dataset = None  # set when a JSON file is processed; checked below
+         for file in files:
+             if file.name.endswith(".pdf"):
+                 with pdfplumber.open(file.name) as pdf:
+                     for page in pdf.pages:
+                         raw_text += page.extract_text() or ""
+             elif file.name.endswith(".json"):
+                 with open(file.name, "r", encoding="utf-8") as f:
+                     raw_data = json.load(f)
+                 training_data = raw_data.get("training_pairs", raw_data)
+                 with open("temp_fraud_data.json", "w", encoding="utf-8") as f:
+                     json.dump({"training_pairs": training_data}, f)
+                 dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
+
+         if not raw_text and dataset is None:
+             return "Error: No valid PDF or JSON data found."
+
+         # Create training pairs from PDFs if no JSON
+         if raw_text:
+             def create_training_pairs(text):
+                 pairs = []
+                 if "Haloperidol" in text and "daily" in text.lower():
+                     pairs.append({
+                         "input": "Patient received Haloperidol daily. Is this overmedication?",
+                         "output": "Yes, daily Haloperidol use without documented severe psychosis or failed alternatives may indicate overmedication, violating CMS guidelines."
+                     })
+                 if "Lorazepam" in text and "frequent" in text.lower():
+                     pairs.append({
+                         "input": "Care logs show frequent Lorazepam use with a 90-day supply. Is this suspicious?",
+                         "output": "Yes, frequent use with a large supply suggests potential overuse or mismanagement, a fraud indicator."
+                     })
+                 return pairs
+             training_data = create_training_pairs(raw_text)
+             with open("temp_fraud_data.json", "w") as f:
+                 json.dump({"training_pairs": training_data}, f)
+             dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
+
+         # Tokenization function
          def tokenize_data(example):
              formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
+             inputs = tokenizer(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
              inputs["labels"] = inputs["input_ids"].clone()
              return {k: v.squeeze(0) for k, v in inputs.items()}
+
+         # tokenize_data formats one example at a time, so map must run unbatched;
+         # set_format("torch") keeps tensors so the collator's torch.stack works
+         tokenized_dataset = dataset["train"].map(tokenize_data, batched=False, remove_columns=dataset["train"].column_names)
+         tokenized_dataset.set_format("torch")
+
+         # Training setup
          training_args = TrainingArguments(
+             output_dir="./fine_tuned_llama_healthcare",
+             per_device_train_batch_size=4,
+             gradient_accumulation_steps=8,
+             eval_strategy="no",
              save_strategy="epoch",
              save_total_limit=2,
+             num_train_epochs=5,
              learning_rate=2e-5,
              weight_decay=0.01,
              logging_dir="./logs",
              logging_steps=10,
+             bf16=True,
+             gradient_checkpointing=True,
              optim="adamw_torch",
              warmup_steps=100,
          )
+
          def custom_data_collator(features):
+             return {
                  "input_ids": torch.stack([f["input_ids"] for f in features]),
                  "attention_mask": torch.stack([f["attention_mask"] for f in features]),
                  "labels": torch.stack([f["labels"] for f in features]),
              }

          trainer = Trainer(
              model=model,
              args=training_args,
              train_dataset=tokenized_dataset,
              data_collator=custom_data_collator,
          )
          trainer.train()
+         model.save_pretrained("./fine_tuned_llama_healthcare")
+         tokenizer.save_pretrained("./fine_tuned_llama_healthcare")
+         return "Training completed! Model saved to ./fine_tuned_llama_healthcare"
+
      except Exception as e:
+         return f"Error: {str(e)}. Please check file format, dependencies, or HF_TOKEN."

  # Gradio UI
+ with gr.Blocks(title="Healthcare Fraud Detection Fine-Tuning") as demo:
+     gr.Markdown("# Fine-Tune LLaMA 2 for Healthcare Fraud Analysis")
+     gr.Markdown("Upload PDFs (e.g., care logs, medication records) or a JSON file with training pairs.")
+     file_input = gr.File(label="Upload Files (PDF/JSON)", file_count="multiple")
      train_button = gr.Button("Start Fine-Tuning")
+     output = gr.Textbox(label="Training Status", lines=5)
+     train_button.click(fn=train_ui, inputs=file_input, outputs=output)

  demo.launch()
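
For reference, a minimal sketch of the JSON payload train_ui accepts. The "training_pairs", "input", and "output" keys come straight from the code above; the helper script name and the exact sample wording are illustrative only.

# make_sample_dataset.py — hypothetical helper, not part of app.py
import json

sample = {
    "training_pairs": [
        {
            "input": "Patient received Haloperidol daily. Is this overmedication?",
            "output": "Yes, daily Haloperidol use without documented justification may indicate overmedication."
        }
    ]
}

# Write the sample file; upload it through the Gradio file input
with open("sample_fraud_data.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, indent=4)

A bare list of {"input": ..., "output": ...} objects also works, since train_ui falls back to the flat structure via raw_data.get("training_pairs", raw_data).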
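And a hedged sketch of loading the saved LoRA adapter for inference afterwards, assuming peft's standard PeftModel.from_pretrained API. Only the paths, model ID, and prompt format come from the diff above; the script name is illustrative.

# inference_sketch.py — hypothetical follow-up script, not part of app.py
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

# Tokenizer was saved alongside the adapter by train_ui
tokenizer = LlamaTokenizer.from_pretrained("./fine_tuned_llama_healthcare")
base = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# Training added a [PAD] token, so match the base vocabulary to the tokenizer
# before attaching the adapter
base.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(base, "./fine_tuned_llama_healthcare")

# Same [INST] format the training data used
prompt = "<s>[INST] Care logs show frequent Lorazepam use with a 90-day supply. Is this suspicious? [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))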