# safe-talk/predict_pipeline.py
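"""Prediction pipeline for the safe-talk Space.

Merges a free-text case description with an optional WhatsApp chat export
(a zip of .txt files) and classifies the result with the saved abuse model.
Chat summarization and translation are currently mocked.
"""
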
import tempfile
import zipfile
from pathlib import Path

import torch

from train_abuse_model import (
    device,
    load_saved_model_and_tokenizer,
    map_to_3_classes,
    convert_to_label_strings,
)

def run_prediction_pipeline(desc_input, chat_zip):
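    """Classify a case description, optionally augmented with an uploaded chat zip.

    Returns a (merged_input, predicted_labels) tuple; on failure, returns an
    error message and an empty string.
    """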
try:
        # Start with the base input (guard against a missing description)
        merged_input = (desc_input or "").strip()
# If a chat zip was uploaded
if chat_zip:
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(chat_zip.name, 'r') as zip_ref:
zip_ref.extractall(tmpdir)
                chat_texts = []
                # Read each exported chat file in a stable order,
                # skipping undecodable bytes
                for file in sorted(Path(tmpdir).glob("*.txt")):
                    with open(file, encoding="utf-8", errors="ignore") as f:
                        chat_texts.append(f.read())
                full_chat = "\n".join(chat_texts)
                # 🧠 MOCK summarization: a real summarizer would consume
                # full_chat; for now the chat text is read but not used
                summary = "[Mock summary of Hebrew WhatsApp chat...]"
                # 🌍 MOCK translation of the summary into English
                translated_summary = "[Translated summary in English]"
                merged_input = f"{merged_input}\n\n[Summary]: {translated_summary}"
# Load classifier
tokenizer, model = load_saved_model_and_tokenizer()
        inputs = tokenizer(
            merged_input,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs).logits
        # Sigmoid scores each label independently (per-label probabilities)
        probs = torch.sigmoid(outputs).cpu().numpy()
# Static threshold values (or load from config later)
best_low, best_high = 0.35, 0.65
pred_soft = map_to_3_classes(probs, best_low, best_high)
pred_str = convert_to_label_strings(pred_soft)
return merged_input, ", ".join(pred_str)
except Exception as e:
return f"❌ Prediction failed: {e}", ""