rshakked commited on
Commit
96bc19d
Β·
1 Parent(s): 9a19cb3

fix: ensure trained model is correctly detected using absolute path

Browse files
Files changed (1) hide show
  1. train_abuse_model.py +4 -15
train_abuse_model.py CHANGED
@@ -41,7 +41,8 @@ from utils import (
41
  tune_thresholds,
42
  label_map,
43
  label_row_soft,
44
- AbuseDataset
 
45
  )
46
 
47
  # Create evaluation results directory if it doesn't exist
@@ -52,18 +53,6 @@ MODEL_DIR = PERSIST_DIR / "saved_model"
52
  LOG_FILE = PERSIST_DIR / "training.log"
53
 
54
 
55
- # Save and print evaluation results
56
- def save_and_yield_eval(report: str):
57
- # Generate versioned filename using timestamp
58
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
59
-
60
- eval_filename = f"eval_report_{timestamp}.txt"
61
- eval_filepath = Path("/home/user/app/results_eval") / eval_filename
62
-
63
- with open(eval_filepath, "w") as f:
64
- f.write(report)
65
- yield f"πŸ“„ Evaluation saved to: {eval_filepath.name}"
66
- yield report
67
  # configure logging
68
  log_buffer = io.StringIO()
69
  logging.basicConfig(
@@ -135,7 +124,7 @@ def load_saved_model_and_tokenizer():
135
  return tokenizer, model
136
 
137
  def evaluate_saved_model(progress=gr.Progress(track_tqdm=True)):
138
- if os.path.exists("saved_model/"):
139
  yield "βœ… Trained model found! Skipping training...\n"
140
  else:
141
  yield "❌ No trained model found. Please train the model first.\n"
@@ -218,7 +207,7 @@ model_name = "microsoft/deberta-v3-base"
218
 
219
  def run_training(progress=gr.Progress(track_tqdm=True)):
220
  log_queue = queue.Queue()
221
- if os.path.exists("saved_model/"):
222
  yield "βœ… Trained model found! Skipping training...\n"
223
  for line in evaluate_saved_model():
224
  yield line
 
41
  tune_thresholds,
42
  label_map,
43
  label_row_soft,
44
+ AbuseDataset,
45
+ save_and_yield_eval
46
  )
47
 
48
  # Create evaluation results directory if it doesn't exist
 
53
  LOG_FILE = PERSIST_DIR / "training.log"
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # configure logging
57
  log_buffer = io.StringIO()
58
  logging.basicConfig(
 
124
  return tokenizer, model
125
 
126
  def evaluate_saved_model(progress=gr.Progress(track_tqdm=True)):
127
+ if MODEL_DIR.exists():
128
  yield "βœ… Trained model found! Skipping training...\n"
129
  else:
130
  yield "❌ No trained model found. Please train the model first.\n"
 
207
 
208
  def run_training(progress=gr.Progress(track_tqdm=True)):
209
  log_queue = queue.Queue()
210
+ if MODEL_DIR.exists():
211
  yield "βœ… Trained model found! Skipping training...\n"
212
  for line in evaluate_saved_model():
213
  yield line