Cylanoid committed
Commit 420d0a9 · verified · 1 Parent(s): 517984d

Update app.py

Files changed (1)
  1. app.py +2 -38
app.py CHANGED
@@ -28,8 +28,7 @@ except ImportError as e:
 from accelerate import Accelerator
 import bitsandbytes
 
-# Rest of your script remains the same...
-# Model setup, training function, Gradio UI, etc., as shown in your previous script
+# Model setup
 MODEL_ID = "meta-llama/Llama-2-7b-hf" # Use Llama-2-7b; switch to "meta-llama/Llama-3-8b-hf" for Llama 3
 tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID)
 
@@ -122,39 +121,4 @@ def train_ui_tars(file):
         def custom_data_collator(features):
             batch = {
                 "input_ids": torch.stack([f["input_ids"] for f in features]),
-                "attention_mask": torch.stack([f["attention_mask"] for f in features]),
-                "labels": torch.stack([f["labels"] for f in features]),
-            }
-            return batch
-
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=tokenized_dataset,
-            data_collator=custom_data_collator,
-        )
-
-        # Step 4: Start training
-        trainer.train()
-
-        # Step 5: Save the model
-        model.save_pretrained("./fine_tuned_llama")
-        tokenizer.save_pretrained("./fine_tuned_llama")
-
-        return "Training completed successfully! Model saved to ./fine_tuned_llama"
-
-    except Exception as e:
-        return f"Error: {str(e)}"
-
-# Gradio UI
-with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
-    gr.Markdown("# Llama Fraud Detection Fine-Tuning UI")
-    gr.Markdown("Upload a JSON file with 'input' and 'output' pairs to fine-tune the Llama model on your fraud dataset.")
-
-    file_input = gr.File(label="Upload Fraud Dataset (JSON)")
-    train_button = gr.Button("Start Fine-Tuning")
-    output = gr.Textbox(label="Training Status")
-
-    train_button.click(fn=train_ui_tars, inputs=file_input, outputs=output)
-
-demo.launch()
+                "attention_mask": torch.stack([f["attention_mask"] for f in features]),
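Note that the second hunk leaves custom_data_collator cut off mid-dict in the new app.py. For reference, the collator as it stood before this commit can be reassembled as standalone code; the sketch below is grounded in the removed lines, while the dummy-batch check at the end is an illustrative assumption, not part of the original script.

import torch

# Collator as it appeared before this commit, lifted out of train_ui_tars.
# Assumes each dataset item already holds fixed-length torch tensors
# (e.g., tokenized with padding="max_length" and truncation=True).
def custom_data_collator(features):
    batch = {
        "input_ids": torch.stack([f["input_ids"] for f in features]),
        "attention_mask": torch.stack([f["attention_mask"] for f in features]),
        "labels": torch.stack([f["labels"] for f in features]),
    }
    return batch

# Hypothetical smoke test with dummy data (not in the original app.py):
dummy = [{k: torch.ones(8, dtype=torch.long)
          for k in ("input_ids", "attention_mask", "labels")}
         for _ in range(4)]
print(custom_data_collator(dummy)["input_ids"].shape)  # torch.Size([4, 8])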
 