Cylanoid committed
Commit 4b6c42c · verified · 1 Parent(s): 9b8712f

Update app.py

Files changed (1)
  1. app.py +38 -157
app.py CHANGED
@@ -1,163 +1,44 @@
- # app.py
-
  import gradio as gr
- from transformers import LlamaForCausalLM, LlamaTokenizer
- import datasets
+ from transformers import LlamaTokenizer, LlamaForCausalLM
  import torch
- import json
- import os
- import pdfplumber
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
- from accelerate import Accelerator
- import bitsandbytes
- import sentencepiece
- import huggingface_hub
- from transformers import TrainingArguments, Trainer
-
- # Debug: Print all environment variables to verify 'LLama' is present
- print("Environment variables:", dict(os.environ))
-
- # Retrieve the token from Hugging Face Space secrets
- # Token placement: LLama:levi put token here
- LLama = os.getenv("LLama")  # Retrieves the value of the 'LLama' environment variable
- if not LLama:
-     raise ValueError("LLama token not found in environment variables. Please set it in Hugging Face Space secrets under 'Settings' > 'Secrets' as 'LLama'.")
-
- # Debug: Print the token to verify it's being read (remove this in production)
- print(f"Retrieved LLama token: {LLama[:5]}... (first 5 chars for security)")
-
- # Authenticate with Hugging Face
- huggingface_hub.login(token=LLama)
-
- # Model setup
- MODEL_ID = "meta-llama/Llama-2-7b-hf"
- tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-
- # Load model with default attention mechanism (no Flash Attention)
- model = LlamaForCausalLM.from_pretrained(
-     MODEL_ID,
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
-     load_in_8bit=True
- )
-
- # Add padding token if it doesn't exist and resize embeddings
- if tokenizer.pad_token is None:
-     tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-     model.resize_token_embeddings(len(tokenizer))
-
- # Prepare model for LoRA training
- model = prepare_model_for_kbit_training(model)
- peft_config = LoraConfig(
-     r=16,
-     lora_alpha=32,
-     lora_dropout=0.05,
-     bias="none",
-     task_type="CAUSAL_LM",
-     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
- )
- model = get_peft_model(model, peft_config)
- model.print_trainable_parameters()
  
- # Function to process uploaded files and train
- def train_ui(files):
+ # Load the fine-tuned model and tokenizer
+ try:
+     tokenizer = LlamaTokenizer.from_pretrained("./fine_tuned_llama2")
+     model = LlamaForCausalLM.from_pretrained("./fine_tuned_llama2")
+     model.eval()
+     print("Model and tokenizer loaded successfully.")
+ except Exception as e:
+     print(f"Error loading model or tokenizer: {e}")
+
+ # Function to predict fraud based on text input
+ def predict(input_text):
+     if not input_text:
+         return "Please enter some text to analyze."
      try:
-         # Process multiple PDFs or JSON
-         raw_text = ""
-         dataset = None  # Initialize dataset as None
-         for file in files:
-             if file.name.endswith(".pdf"):
-                 with pdfplumber.open(file.name) as pdf:
-                     for page in pdf.pages:
-                         raw_text += page.extract_text() or ""
-             elif file.name.endswith(".json"):
-                 with open(file.name, "r", encoding="utf-8") as f:
-                     raw_data = json.load(f)
-                     training_data = raw_data.get("training_pairs", raw_data)
-                 with open("temp_fraud_data.json", "w", encoding="utf-8") as f:
-                     json.dump({"training_pairs": training_data}, f)
-                 dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
-
-         if not raw_text and not dataset:
-             return "Error: No valid PDF or JSON data found."
-
-         # Create training pairs from PDFs if no JSON
-         if raw_text:
-             def create_training_pairs(text):
-                 pairs = []
-                 if "Haloperidol" in text and "daily" in text.lower():
-                     pairs.append({
-                         "input": "Patient received Haloperidol daily. Is this overmedication?",
-                         "output": "Yes, daily Haloperidol use without documented severe psychosis or failed alternatives may indicate overmedication, violating CMS guidelines."
-                     })
-                 if "Lorazepam" in text and "frequent" in text.lower():
-                     pairs.append({
-                         "input": "Care logs show frequent Lorazepam use with a 90-day supply. Is this suspicious?",
-                         "output": "Yes, frequent use with a large supply suggests potential overuse or mismanagement, a fraud indicator."
-                     })
-                 return pairs
-             training_data = create_training_pairs(raw_text)
-             with open("temp_fraud_data.json", "w") as f:
-                 json.dump({"training_pairs": training_data}, f)
-             dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
-
-         # Tokenization function
-         def tokenize_data(example):
-             formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
-             inputs = tokenizer(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
-             inputs["labels"] = inputs["input_ids"].clone()
-             return {k: v.squeeze(0) for k, v in inputs.items()}
-
-         tokenized_dataset = dataset["train"].map(tokenize_data, batched=True, remove_columns=dataset["train"].column_names)
-
-         # Training setup
-         training_args = TrainingArguments(
-             output_dir="./fine_tuned_llama_healthcare",
-             per_device_train_batch_size=4,
-             gradient_accumulation_steps=8,
-             eval_strategy="no",
-             save_strategy="epoch",
-             save_total_limit=2,
-             num_train_epochs=5,
-             learning_rate=2e-5,
-             weight_decay=0.01,
-             logging_dir="./logs",
-             logging_steps=10,
-             bf16=True,
-             gradient_checkpointing=True,
-             optim="adamw_torch",
-             warmup_steps=100,
-         )
-
-         def custom_data_collator(features):
-             return {
-                 "input_ids": torch.stack([f["input_ids"] for f in features]),
-                 "attention_mask": torch.stack([f["attention_mask"] for f in features]),
-                 "labels": torch.stack([f["labels"] for f in features]),
-             }
-
-         trainer = Trainer(
-             model=model,
-             args=training_args,
-             train_dataset=tokenized_dataset,
-             data_collator=custom_data_collator,
-         )
-         trainer.train()
-         model.save_pretrained("./fine_tuned_llama_healthcare")
-         tokenizer.save_pretrained("./fine_tuned_llama_healthcare")
-         return "Training completed! Model saved to ./fine_tuned_llama_healthcare"
-
+         # Tokenize input
+         inputs = tokenizer(input_text, return_tensors="pt", max_length=512, padding="max_length", truncation=True)
+         # Generate output
+         with torch.no_grad():
+             outputs = model.generate(**inputs, max_new_tokens=50)
+         # Decode and return result
+         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return result
      except Exception as e:
-         return f"Error: {str(e)}. Please check file format, dependencies, or the LLama token."
-
- # Gradio UI
- with gr.Blocks(title="Healthcare Fraud Detection Fine-Tuning") as demo:
-     gr.Markdown("# Fine-Tune LLaMA 2 for Healthcare Fraud Analysis")
-     gr.Markdown("Upload PDFs (e.g., care logs, medication records) or a JSON file with training pairs.")
-     file_input = gr.File(label="Upload Files (PDF/JSON)", file_count="multiple")
-     train_button = gr.Button("Start Fine-Tuning")
-     output = gr.Textbox(label="Training Status", lines=5)
-     train_button.click(fn=train_ui, inputs=file_input, outputs=output)
+         return f"Error during prediction: {e}"
+
+ # Create Gradio interface with text input
+ interface = gr.Interface(
+     fn=predict,
+     inputs=gr.Textbox(
+         lines=2,
+         placeholder="Enter text to analyze (e.g., 'Facility backdates policies. Is this fraudulent?')",
+         label="Input Text"
+     ),
+     outputs=gr.Textbox(label="Prediction"),
+     title="Fine-Tune LLaMA 2 for Healthcare Fraud Analysis",
+     description="Test the fine-tuned LLaMA 2 model to detect healthcare fraud. Enter a description of a facility's behavior to analyze."
+ )
  
- # Launch the Gradio app
- demo.launch()
+ # Launch the interface
+ interface.launch()