Cylanoid committed (verified)
Commit e6fa3be · 1 Parent(s): 420d0a9

Update app.py

Files changed (1):
  1. app.py +37 -2
app.py CHANGED
@@ -117,8 +117,43 @@ def train_ui_tars(file):
             warmup_steps=100,
         )
 
-        # Custom data collator for Llama
+        # Custom data collator for Llama (corrected)
         def custom_data_collator(features):
             batch = {
                 "input_ids": torch.stack([f["input_ids"] for f in features]),
-                "attention_mask": torch.stack([f["
+                "attention_mask": torch.stack([f["attention_mask"] for f in features]),
+                "labels": torch.stack([f["labels"] for f in features]),
+            }
+            return batch
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=tokenized_dataset,
+            data_collator=custom_data_collator,
+        )
+
+        # Step 4: Start training
+        trainer.train()
+
+        # Step 5: Save the model
+        model.save_pretrained("./fine_tuned_llama")
+        tokenizer.save_pretrained("./fine_tuned_llama")
+
+        return "Training completed successfully! Model saved to ./fine_tuned_llama"
+
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+# Gradio UI
+with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
+    gr.Markdown("# Llama Fraud Detection Fine-Tuning UI")
+    gr.Markdown("Upload a JSON file with 'input' and 'output' pairs to fine-tune the Llama model on your fraud dataset.")
+
+    file_input = gr.File(label="Upload Fraud Dataset (JSON)")
+    train_button = gr.Button("Start Fine-Tuning")
+    output = gr.Textbox(label="Training Status")
+
+    train_button.click(fn=train_ui_tars, inputs=file_input, outputs=output)
+
+demo.launch()
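
The UI text added in this commit says the uploader expects a JSON file with 'input' and 'output' pairs. The exact schema consumed by train_ui_tars() is not visible in this diff, so the following is only a minimal sketch of what such a fraud-dataset file might look like; the field values and file name are assumptions, not taken from the repository.

```python
# Hypothetical example of the dataset format implied by the UI text:
# a JSON list of records with "input" and "output" fields.
# The actual schema expected by train_ui_tars() is not shown in this diff.
import json

sample_records = [
    {"input": "Transaction: $9,800 wire transfer to a new overseas account", "output": "fraud"},
    {"input": "Transaction: $42 grocery purchase at a regular merchant", "output": "legitimate"},
]

with open("fraud_dataset.json", "w") as f:
    json.dump(sample_records, f, indent=2)
```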
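After training, the function saves the model and tokenizer to ./fine_tuned_llama via save_pretrained. A minimal sketch of reloading those artifacts for inference with the standard transformers Auto classes is shown below; the model class, prompt format, and generation settings are assumptions, since the original model-loading code sits outside this diff.

```python
# Sketch: reload the fine-tuned model saved by train_ui_tars().
# Assumes the base model is a causal LM loadable via AutoModelForCausalLM;
# how the model was originally loaded is not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./fine_tuned_llama")
model = AutoModelForCausalLM.from_pretrained("./fine_tuned_llama")
model.eval()

# Hypothetical prompt mirroring the assumed 'input'/'output' training pairs.
prompt = "Transaction: $9,800 wire transfer to a new overseas account\nLabel:"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```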