Update app.py
app.py CHANGED
@@ -28,8 +28,7 @@ except ImportError as e:
 from accelerate import Accelerator
 import bitsandbytes
 
-#
-# Model setup, training function, Gradio UI, etc., as shown in your previous script
+# Model setup
 MODEL_ID = "meta-llama/Llama-2-7b-hf"  # Use Llama-2-7b; switch to "meta-llama/Llama-3-8b-hf" for Llama 3
 tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID)
 
@@ -122,39 +121,4 @@ def train_ui_tars(file):
         def custom_data_collator(features):
             batch = {
                 "input_ids": torch.stack([f["input_ids"] for f in features]),
-                "attention_mask": torch.stack([f["
-                "labels": torch.stack([f["labels"] for f in features]),
-            }
-            return batch
-
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=tokenized_dataset,
-            data_collator=custom_data_collator,
-        )
-
-        # Step 4: Start training
-        trainer.train()
-
-        # Step 5: Save the model
-        model.save_pretrained("./fine_tuned_llama")
-        tokenizer.save_pretrained("./fine_tuned_llama")
-
-        return "Training completed successfully! Model saved to ./fine_tuned_llama"
-
-    except Exception as e:
-        return f"Error: {str(e)}"
-
-# Gradio UI
-with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
-    gr.Markdown("# Llama Fraud Detection Fine-Tuning UI")
-    gr.Markdown("Upload a JSON file with 'input' and 'output' pairs to fine-tune the Llama model on your fraud dataset.")
-
-    file_input = gr.File(label="Upload Fraud Dataset (JSON)")
-    train_button = gr.Button("Start Fine-Tuning")
-    output = gr.Textbox(label="Training Status")
-
-    train_button.click(fn=train_ui_tars, inputs=file_input, outputs=output)
-
-demo.launch()
+                "attention_mask": torch.stack([f["
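Note that the second hunk leaves the new file cut off mid-line: per the hunk header @@ -122,39 +121,4 @@, app.py now ends at line 124, partway through the "attention_mask" entry, so the committed collator (and everything the diff removes after it) is incomplete. For reference, a complete collator following the pattern of the removed lines might look like this sketch; it assumes each tokenized feature dict carries equal-length "input_ids", "attention_mask", and "labels" tensors, as the surrounding code implies:

import torch

def custom_data_collator(features):
    # Stack the per-example tensors into batch tensors. This only works if
    # every example was padded/truncated to the same length at tokenization
    # time (e.g., with a fixed max_length), since torch.stack requires
    # tensors of identical shape.
    batch = {
        "input_ids": torch.stack([f["input_ids"] for f in features]),
        "attention_mask": torch.stack([f["attention_mask"] for f in features]),
        "labels": torch.stack([f["labels"] for f in features]),
    }
    return batch

If the examples are variable-length instead, a padding-aware collator such as transformers' DataCollatorForLanguageModeling would be the usual choice rather than stacking by hand.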