Update app.py
Browse files
app.py
CHANGED
@@ -117,8 +117,43 @@ def train_ui_tars(file):
|
|
117 |
warmup_steps=100,
|
118 |
)
|
119 |
|
120 |
-
# Custom data collator for Llama
|
121 |
def custom_data_collator(features):
|
122 |
batch = {
|
123 |
"input_ids": torch.stack([f["input_ids"] for f in features]),
|
124 |
-
"attention_mask": torch.stack([f["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
warmup_steps=100,
|
118 |
)
|
119 |
|
120 |
+
# Custom data collator for Llama (corrected)
|
121 |
def custom_data_collator(features):
    """Collate a list of tokenized examples into one batch of stacked tensors.

    Args:
        features: Sequence of mappings, each providing "input_ids",
            "attention_mask" and "labels". Values may be tensors or plain
            sequences of ints (e.g. Python lists).

    Returns:
        A dict with the same three keys, each holding the per-example values
        stacked along a new leading (batch) dimension.

    Raises:
        RuntimeError: Propagated from ``torch.stack`` when ``features`` is
            empty or the examples do not all have the same length; no
            padding is performed here.
    """
    # torch.stack only accepts tensors, so convert each value first.
    # torch.as_tensor leaves existing tensors untouched, which keeps the
    # original behavior for datasets that already yield tensors.
    batch = {
        key: torch.stack([torch.as_tensor(example[key]) for example in features])
        for key in ("input_ids", "attention_mask", "labels")
    }
    return batch
|
128 |
+
|
129 |
+
trainer = Trainer(
|
130 |
+
model=model,
|
131 |
+
args=training_args,
|
132 |
+
train_dataset=tokenized_dataset,
|
133 |
+
data_collator=custom_data_collator,
|
134 |
+
)
|
135 |
+
|
136 |
+
# Step 4: Start training
|
137 |
+
trainer.train()
|
138 |
+
|
139 |
+
# Step 5: Save the model
|
140 |
+
model.save_pretrained("./fine_tuned_llama")
|
141 |
+
tokenizer.save_pretrained("./fine_tuned_llama")
|
142 |
+
|
143 |
+
return "Training completed successfully! Model saved to ./fine_tuned_llama"
|
144 |
+
|
145 |
+
except Exception as e:
|
146 |
+
return f"Error: {str(e)}"
|
147 |
+
|
148 |
+
# Gradio UI
|
149 |
+
with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
    # Page heading, then a short usage note shown above the controls.
    gr.Markdown(value="# Llama Fraud Detection Fine-Tuning UI")
    gr.Markdown(
        value="Upload a JSON file with 'input' and 'output' pairs to fine-tune the Llama model on your fraud dataset."
    )

    # Controls, created in display order: upload, trigger, status readout.
    file_input = gr.File(label="Upload Fraud Dataset (JSON)")
    train_button = gr.Button(value="Start Fine-Tuning")
    output = gr.Textbox(label="Training Status")

    # Clicking the button passes the uploaded file to train_ui_tars and
    # writes the string it returns into the status textbox.
    train_button.click(train_ui_tars, file_input, output)

# Start the Gradio server for the interface built above.
demo.launch()
|