Cylanoid committed on
Commit bf713b8 · verified · 1 Parent(s): 53d6f71

Update app.py

Files changed (1)
  1. app.py +98 -29
app.py CHANGED
@@ -1,17 +1,65 @@
- import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
- import datasets
- import torch
- import json
- import os
- import accelerate
- except ImportError:
- os.system('pip install "accelerate>=0.26.0"')

  # Model setup
- MODEL_ID = "facebook/opt-350m" # Smaller, open access model
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")

  # Function to process uploaded JSON and train
  def train_ui_tars(file):
@@ -31,51 +79,73 @@ def train_ui_tars(file):
  # Load dataset
  dataset = datasets.load_dataset("json", data_files=fixed_json_path)

- # Step 2: Tokenize dataset
  def tokenize_data(example):
- inputs = tokenizer(example["input"], padding="max_length", truncation=True, max_length=512)
- targets = tokenizer(example["output"], padding="max_length", truncation=True, max_length=512)
- inputs["labels"] = targets["input_ids"]
- return inputs

- tokenized_dataset = dataset.map(tokenize_data, batched=True)

  # Step 3: Training setup
  training_args = TrainingArguments(
- output_dir="./fine_tuned_llama2",
- per_device_train_batch_size=2,
  evaluation_strategy="no",
  save_strategy="epoch",
  save_total_limit=2,
  num_train_epochs=3,
  learning_rate=2e-5,
  weight_decay=0.01,
- logging_dir="./logs"
  )

  trainer = Trainer(
  model=model,
  args=training_args,
- train_dataset=tokenized_dataset["train"],
- data_collator=DataCollatorForSeq2Seq(tokenizer, model=model)
  )

  # Step 4: Start training
  trainer.train()

  # Step 5: Save the model
- model.save_pretrained("train_llama.py")
- tokenizer.save_pretrained("./train_llama.py")

- return "Training completed successfully! Model saved to ./train_llama.py"

  except Exception as e:
  return f"Error: {str(e)}"

  # Gradio UI
  with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
- gr.Markdown("train_llama.py")
- gr.Markdown("Upload a JSON file with 'input' and 'output' pairs to fine-tune the model on your fraud dataset.")

  file_input = gr.File(label="Upload Fraud Dataset (JSON)")
  train_button = gr.Button("Start Fine-Tuning")
@@ -83,5 +153,4 @@ with gr.Blocks(title="Model Fine-Tuning Interface") as demo:

  train_button.click(fn=train_ui_tars, inputs=file_input, outputs=output)

- # Launch the app
  demo.launch()
 
+ # app.py
+
+ # Handle missing dependencies first
+ try:
+ import gradio as gr
+ from transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments
+ import datasets
+ import torch
+ import json
+ import os
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ from accelerate import Accelerator
+ import bitsandbytes
+ except ImportError as e:
+ missing_package = str(e).split("'")[-2] # Extract the missing package name
+ os.system(f'pip install "{missing_package}>=0.26.0"' if "accelerate" in missing_package else f'pip install {missing_package}')
+ # Re-import after installation
+ import gradio as gr
+ from transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments
+ import datasets
+ import torch
+ import json
+ import os
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ from accelerate import Accelerator
+ import bitsandbytes

  # Model setup
+ MODEL_ID = "meta-llama/Llama-2-7b-hf" # Use Llama-2-7b; switch to "meta-llama/Llama-3-8b-hf" for Llama 3
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID)
+
+ # Add padding token if it doesn't exist (required for Llama models)
+ if tokenizer.pad_token is None:
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+ # Check if CUDA is available to enable Flash Attention 2
+ use_flash_attention = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 # Ampere or newer (e.g., A100)
+
+ # Load the model with optimizations for Llama
+ model = LlamaForCausalLM.from_pretrained(
+ MODEL_ID,
+ torch_dtype=torch.bfloat16, # Better for A100 GPUs, falls back to float16 on CPU
+ device_map="auto",
+ use_flash_attention_2=use_flash_attention, # Only enable if GPU supports it
+ load_in_8bit=True # Quantization for memory efficiency
+ )
+
+ # Prepare the model for training with LoRA (more memory-efficient)
+ model = prepare_model_for_kbit_training(model)
+
+ # LoRA configuration
+ peft_config = LoraConfig(
+ r=16, # Rank
+ lora_alpha=32, # Alpha
+ lora_dropout=0.05, # Dropout
+ bias="none",
+ task_type="CAUSAL_LM",
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"] # Attention modules for Llama
+ )
+
+ model = get_peft_model(model, peft_config)
+ model.print_trainable_parameters() # Print percentage of trainable parameters

  # Function to process uploaded JSON and train
  def train_ui_tars(file):
 
  # Load dataset
  dataset = datasets.load_dataset("json", data_files=fixed_json_path)

+ # Step 2: Tokenize dataset with Llama-compatible context length
  def tokenize_data(example):
+ # Format input for Llama (instruction-following style)
+ formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
+ inputs = tokenizer(
+ formatted_text,
+ padding="max_length",
+ truncation=True,
+ max_length=2048, # Llama 2 context length; adjust to 8192 for Llama 3 if needed
+ return_tensors="pt"
+ )
+ inputs["labels"] = inputs["input_ids"].clone()
+ return {k: v.squeeze(0) for k, v in inputs.items()}

+ tokenized_dataset = dataset["train"].map(tokenize_data, batched=False, remove_columns=dataset["train"].column_names) # batched=False: tokenize_data formats one example at a time

  # Step 3: Training setup
  training_args = TrainingArguments(
+ output_dir="./fine_tuned_llama",
+ per_device_train_batch_size=4, # Increased for better efficiency
+ gradient_accumulation_steps=8, # To handle larger effective batch size
  evaluation_strategy="no",
  save_strategy="epoch",
  save_total_limit=2,
  num_train_epochs=3,
  learning_rate=2e-5,
  weight_decay=0.01,
+ logging_dir="./logs",
+ logging_steps=10,
+ bf16=True, # Use bfloat16 for A100 GPUs, falls back to float16 on CPU
+ gradient_checkpointing=True, # Memory optimization
+ optim="adamw_torch",
+ warmup_steps=100,
  )

+ # Custom data collator for Llama
+ def custom_data_collator(features):
+ batch = {
+ "input_ids": torch.tensor([f["input_ids"] for f in features]), # dataset rows come back as Python lists, so build tensors directly
+ "attention_mask": torch.tensor([f["attention_mask"] for f in features]),
+ "labels": torch.tensor([f["labels"] for f in features]),
+ }
+ return batch
+
  trainer = Trainer(
  model=model,
  args=training_args,
+ train_dataset=tokenized_dataset,
+ data_collator=custom_data_collator,
  )

  # Step 4: Start training
  trainer.train()

  # Step 5: Save the model
+ model.save_pretrained("./fine_tuned_llama")
+ tokenizer.save_pretrained("./fine_tuned_llama")

+ return "Training completed successfully! Model saved to ./fine_tuned_llama"

  except Exception as e:
  return f"Error: {str(e)}"

  # Gradio UI
  with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
+ gr.Markdown("# Llama Fraud Detection Fine-Tuning UI")
+ gr.Markdown("Upload a JSON file with 'input' and 'output' pairs to fine-tune the Llama model on your fraud dataset.")

  file_input = gr.File(label="Upload Fraud Dataset (JSON)")
  train_button = gr.Button("Start Fine-Tuning")

  train_button.click(fn=train_ui_tars, inputs=file_input, outputs=output)

  demo.launch()
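
For reference, a minimal sketch of a dataset this commit's app should accept: train_ui_tars loads the uploaded file with datasets.load_dataset("json", ...) and tokenize_data reads the "input" and "output" fields, so a JSON array of records with those two string keys is enough. The file name and the fraud wording below are illustrative assumptions, not part of the commit.

# make_sample_dataset.py -- illustrative only; the field names follow the
# "input"/"output" keys used by tokenize_data in app.py above.
import json

records = [
    {
        "input": "Transaction: $4,900 wire transfer to a newly added overseas payee at 3 a.m.",
        "output": "High risk: unusual hour, new beneficiary, amount just under a common reporting threshold.",
    },
    {
        "input": "Transaction: $12.50 recurring streaming subscription, same merchant as prior months.",
        "output": "Low risk: recurring charge consistent with the customer's history.",
    },
]

# Write the sample file, then upload it through the "Upload Fraud Dataset (JSON)" picker.
with open("fraud_dataset.json", "w", encoding="utf-8") as f:
    json.dump(records, f, indent=2)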
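Because the model is wrapped with get_peft_model, model.save_pretrained("./fine_tuned_llama") stores the LoRA adapter rather than full Llama weights. A minimal inference sketch, assuming the base model is still meta-llama/Llama-2-7b-hf, a GPU is available, and the prompt text is hypothetical:

# load_adapter.py -- a sketch, not part of the commit; paths and prompt are assumptions.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

# Reload the base model, then attach the saved LoRA adapter on top of it.
base = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = LlamaTokenizer.from_pretrained("./fine_tuned_llama")  # saved by the app
model = PeftModel.from_pretrained(base, "./fine_tuned_llama")

# Prompt in the same [INST] format used during training in tokenize_data.
prompt = "<s>[INST] Transaction: $4,900 wire transfer to a new overseas payee at 3 a.m. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))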