fix: ensure trained model is correctly detected using absolute path
Browse files- train_abuse_model.py +4 -15
train_abuse_model.py
CHANGED
@@ -41,7 +41,8 @@ from utils import (
|
|
41 |
tune_thresholds,
|
42 |
label_map,
|
43 |
label_row_soft,
|
44 |
-
AbuseDataset
|
|
|
45 |
)
|
46 |
|
47 |
# Create evaluation results directory if it doesn't exist
|
@@ -52,18 +53,6 @@ MODEL_DIR = PERSIST_DIR / "saved_model"
|
|
52 |
LOG_FILE = PERSIST_DIR / "training.log"
|
53 |
|
54 |
|
55 |
-
# Save and print evaluation results
|
56 |
-
def save_and_yield_eval(report: str):
|
57 |
-
# Generate versioned filename using timestamp
|
58 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
59 |
-
|
60 |
-
eval_filename = f"eval_report_{timestamp}.txt"
|
61 |
-
eval_filepath = Path("/home/user/app/results_eval") / eval_filename
|
62 |
-
|
63 |
-
with open(eval_filepath, "w") as f:
|
64 |
-
f.write(report)
|
65 |
-
yield f"π Evaluation saved to: {eval_filepath.name}"
|
66 |
-
yield report
|
67 |
# configure logging
|
68 |
log_buffer = io.StringIO()
|
69 |
logging.basicConfig(
|
@@ -135,7 +124,7 @@ def load_saved_model_and_tokenizer():
|
|
135 |
return tokenizer, model
|
136 |
|
137 |
def evaluate_saved_model(progress=gr.Progress(track_tqdm=True)):
|
138 |
-
if
|
139 |
yield "β
Trained model found! Skipping training...\n"
|
140 |
else:
|
141 |
yield "β No trained model found. Please train the model first.\n"
|
@@ -218,7 +207,7 @@ model_name = "microsoft/deberta-v3-base"
|
|
218 |
|
219 |
def run_training(progress=gr.Progress(track_tqdm=True)):
|
220 |
log_queue = queue.Queue()
|
221 |
-
if
|
222 |
yield "β
Trained model found! Skipping training...\n"
|
223 |
for line in evaluate_saved_model():
|
224 |
yield line
|
|
|
41 |
tune_thresholds,
|
42 |
label_map,
|
43 |
label_row_soft,
|
44 |
+
AbuseDataset,
|
45 |
+
save_and_yield_eval
|
46 |
)
|
47 |
|
48 |
# Create evaluation results directory if it doesn't exist
|
|
|
53 |
LOG_FILE = PERSIST_DIR / "training.log"
|
54 |
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
# configure logging
|
57 |
log_buffer = io.StringIO()
|
58 |
logging.basicConfig(
|
|
|
124 |
return tokenizer, model
|
125 |
|
126 |
def evaluate_saved_model(progress=gr.Progress(track_tqdm=True)):
|
127 |
+
if MODEL_DIR.exists():
|
128 |
yield "β
Trained model found! Skipping training...\n"
|
129 |
else:
|
130 |
yield "β No trained model found. Please train the model first.\n"
|
|
|
207 |
|
208 |
def run_training(progress=gr.Progress(track_tqdm=True)):
|
209 |
log_queue = queue.Queue()
|
210 |
+
if MODEL_DIR.exists():
|
211 |
yield "β
Trained model found! Skipping training...\n"
|
212 |
for line in evaluate_saved_model():
|
213 |
yield line
|