rshakked committed on
Commit 54299e5 · 1 Parent(s): 1294c96

feat: persist model and logs in Hugging Face Space + add model push to hub


- Updated paths to save model and logs to /home/user/app (persistent in Spaces)
- Modified logging to stream to both file and UI log buffer (see the sketch after this list)
- Updated model saving/loading to use MODEL_DIR inside the persistent path
- Added push_model_to_hub() to upload trained model/tokenizer to Hugging Face Hub
- Extended Gradio UI with 'Evaluate Model' and 'Push Model to Hub' buttons
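
A minimal standalone sketch of the dual logging pattern this commit wires up (the file path and messages here are illustrative, not the Space's):

import io
import logging

log_buffer = io.StringIO()  # in-memory sink the UI can read back
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("training.log"),  # persistent file sink
        logging.StreamHandler(log_buffer),    # UI log buffer sink
    ],
)

logging.getLogger(__name__).info("hello from training")
log_buffer.seek(0)
print(log_buffer.read())  # what the Gradio log box would display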

Files changed (2)
  1. app.py +8 -5
  2. train_abuse_model.py +86 -15
app.py CHANGED
@@ -1,17 +1,20 @@
 import gradio as gr
-from train_abuse_model import run_training
+from train_abuse_model import run_training, evaluate_saved_model, push_model_to_hub
 
 with gr.Blocks() as demo:
     gr.Markdown("## 🧠 Abuse Detection Fine-Tuning App")
-    gr.Markdown(
-        "⚠️ **Important:** Keep this tab open and prevent your computer from sleeping while training runs."
-    )
+    gr.Markdown("⚠️ Keep this tab open while training or evaluating.")
+
     with gr.Row():
         start_btn = gr.Button("🚀 Start Training")
+        eval_btn = gr.Button("🔍 Evaluate Trained Model")
+        push_btn = gr.Button("📤 Push Model to Hub")
 
-    output_box = gr.Textbox(label="Live Training Logs", lines=25, interactive=False)
+    output_box = gr.Textbox(label="Logs", lines=25, interactive=False)
 
     start_btn.click(fn=run_training, outputs=output_box)
+    eval_btn.click(fn=evaluate_saved_model, outputs=output_box)
+    push_btn.click(fn=push_model_to_hub, outputs=output_box)
 
 if __name__ == "__main__":
     demo.launch()
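
For context on the wiring above: when a Gradio click handler is a generator (as run_training and evaluate_saved_model are), Gradio streams each yielded value into the bound output component. A minimal sketch of that pattern; fake_training is a stand-in, not the app's function:

import time
import gradio as gr

def fake_training():
    log = ""
    for step in range(3):
        log += f"step {step} done\n"
        time.sleep(0.5)
        yield log  # each yield replaces the Textbox contents

with gr.Blocks() as demo:
    run_btn = gr.Button("Run")
    log_box = gr.Textbox(label="Logs", lines=10, interactive=False)
    run_btn.click(fn=fake_training, outputs=log_box)

if __name__ == "__main__":
    demo.queue()  # generator callbacks require queuing in older Gradio releases
    demo.launch()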
train_abuse_model.py CHANGED
@@ -5,6 +5,7 @@ import logging
 import io
 import os
 import gradio as gr  # ✅ required for progress bar
+from pathlib import Path
 
 # Python standard + ML packages
 import pandas as pd
@@ -30,18 +31,23 @@ from transformers import (
     TrainingArguments
 )
 
+PERSIST_DIR = Path("/home/user/app")
+MODEL_DIR = PERSIST_DIR / "saved_model"
+LOG_FILE = PERSIST_DIR / "training.log"
+
 # configure logging
 log_buffer = io.StringIO()
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s - %(levelname)s - %(message)s",
     handlers=[
-        logging.FileHandler("training.log"),  # to file
-        logging.StreamHandler(log_buffer)     # to in-memory buffer
+        logging.FileHandler(LOG_FILE),        # to persistent file
+        logging.StreamHandler(log_buffer)     # to in-memory buffer
     ]
 )
 logger = logging.getLogger(__name__)
 
+
 # Check versions
 logger.info("Transformers version: %s", transformers.__version__)
@@ -50,6 +56,9 @@ logger.info("Torch version: %s", torch.__version__)
 logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Label mapping for evaluation
+label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
+
 # Custom Dataset class
 
 class AbuseDataset(Dataset):
@@ -127,33 +136,81 @@ def tune_thresholds(probs, true_labels, verbose=True):
 def evaluate_model_with_thresholds(trainer, test_dataset):
     """Run full evaluation with automatic threshold tuning."""
     logger.info("\n🔍 Running model predictions...")
+    yield "\n🔍 Running model predictions..."
+
     predictions = trainer.predict(test_dataset)
     probs = torch.sigmoid(torch.tensor(predictions.predictions)).numpy()
     true_soft = np.array(predictions.label_ids)
 
     logger.info("\n🔎 Tuning thresholds...")
+    yield "\n🔎 Tuning thresholds..."
     best_low, best_high, best_f1 = tune_thresholds(probs, true_soft)
 
     logger.info(f"\n✅ Best thresholds: low={best_low:.2f}, high={best_high:.2f} (macro F1={best_f1:.3f})")
+    yield f"\n✅ Best thresholds: low={best_low:.2f}, high={best_high:.2f} (macro F1={best_f1:.3f})"
 
     final_pred_soft = map_to_3_classes(probs, best_low, best_high)
     final_pred_str = convert_to_label_strings(final_pred_soft)
     true_str = convert_to_label_strings(true_soft)
 
     logger.info("\n📊 Final Evaluation Report (multi-class per label):\n")
+    yield "\n📊 Final Evaluation Report (multi-class per label):\n"
     logger.info(classification_report(
         true_str,
         final_pred_str,
         labels=["no", "plausibly", "yes"],
+        digits=3,
         zero_division=0
     ))
+    yield classification_report(
+        true_str,
+        final_pred_str,
+        labels=["no", "plausibly", "yes"],
+        digits=3,
+        zero_division=0
+    )
 
-    return {
-        "thresholds": (best_low, best_high),
-        "macro_f1": best_f1,
-        "true_labels": true_str,
-        "pred_labels": final_pred_str
-    }
+def load_saved_model_and_tokenizer():
+    tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_DIR)
+    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR).to(device)
+    return tokenizer, model
+
+def evaluate_saved_model(progress=gr.Progress(track_tqdm=True)):
+    if MODEL_DIR.exists():
+        yield "✅ Trained model found.\n"
+    else:
+        yield "❌ No trained model found. Please train the model first.\n"
+        return
+    try:
+        logger.info("🔍 Loading saved model for evaluation...")
+        yield "🔍 Loading saved model for evaluation...\n"
+
+        tokenizer, model = load_saved_model_and_tokenizer()
+        test_dataset = AbuseDataset(test_texts, test_labels, tokenizer)
+
+        trainer = Trainer(
+            model=model,
+            args=TrainingArguments(
+                output_dir="./results_eval",
+                per_device_eval_batch_size=4,
+                logging_dir="./logs_eval",
+                disable_tqdm=True
+            ),
+            eval_dataset=test_dataset
+        )
+
+        # Re-yield from the evaluation generator
+        for line in evaluate_model_with_thresholds(trainer, test_dataset):
+            yield line
+
+        logger.info("✅ Evaluation complete.\n")
+        yield "\n✅ Evaluation complete.\n"
+
+    except Exception as e:
+        logger.exception(f"❌ Evaluation failed: {e}")
+        yield f"❌ Evaluation failed: {e}\n"
 
 token = os.environ.get("HF_TOKEN")  # read from the Space's HF_TOKEN secret
 
@@ -202,6 +259,10 @@ train_texts, val_texts, train_labels, val_labels = train_test_split(
 model_name = "microsoft/deberta-v3-base"
 
 def run_training(progress=gr.Progress(track_tqdm=True)):
+    if MODEL_DIR.exists():
+        yield "✅ Trained model found! Skipping training...\n"
+        yield from evaluate_saved_model()
+        return
     yield "🚀 Starting training...\n"
     try:
         logger.info("Starting training run...")
@@ -269,11 +330,10 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
         trainer.train()
 
         # Save the model and tokenizer
-        if not os.path.exists("saved_model/"):
-            os.makedirs("saved_model/")
-        model.save_pretrained("saved_model/")
-        tokenizer.save_pretrained("saved_model/")
-
+        MODEL_DIR.mkdir(parents=True, exist_ok=True)
+        model.save_pretrained(MODEL_DIR)
+        tokenizer.save_pretrained(MODEL_DIR)
+
         logger.info("Training completed and model saved.")
         yield "🎉 Training complete! Model saved.\n"
@@ -284,7 +344,6 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
     # Evaluation
     try:
         if 'trainer' in locals():
-            label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
             yield from evaluate_model_with_thresholds(trainer, test_dataset)
         logger.info("Evaluation completed")
     except Exception as e:
@@ -292,3 +351,15 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
         log_buffer.seek(0)
         return log_buffer.read()
 
+def push_model_to_hub():
+    try:
+        logger.info("🔄 Pushing model to Hugging Face Hub...")
+        tokenizer, model = load_saved_model_and_tokenizer()
+        model.push_to_hub("rshakked/safe-talk", use_auth_token=token)
+        tokenizer.push_to_hub("rshakked/safe-talk", use_auth_token=token)
+        return "✅ Model pushed to hub successfully!"
+    except Exception as e:
+        logger.exception("❌ Failed to push model to hub.")
+        return f"❌ Failed to push model: {e}"
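
A hedged sketch of round-tripping the push: once push_model_to_hub() succeeds, the artifacts should load straight from the repo id used in the diff above. This assumes a recent transformers release, where the deprecated use_auth_token argument is replaced by token; a private repo also needs the token at load time.

import os
from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo_id = "rshakked/safe-talk"      # repo id from the diff above
token = os.environ.get("HF_TOKEN")  # needed if the repo is private

tokenizer = AutoTokenizer.from_pretrained(repo_id, token=token)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, token=token)
print(type(model).__name__, model.config.num_labels)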