import zipfile
import tempfile
from pathlib import Path

import torch
from transformers import DebertaV2Tokenizer, AutoModelForSequenceClassification

from train_abuse_model import (
    MODEL_DIR,
    device,
    load_saved_model_and_tokenizer,
    map_to_3_classes,
    convert_to_label_strings
)


def run_prediction_pipeline(desc_input, chat_zip):
    """Merge the free-text description with an optional chat export,
    then classify the merged text with the saved abuse model."""
    try:
        # Start with the base input
        merged_input = desc_input.strip()

        # If a chat zip was uploaded, extract it and read every .txt file
        if chat_zip:
            with tempfile.TemporaryDirectory() as tmpdir:
                with zipfile.ZipFile(chat_zip.name, "r") as zip_ref:
                    zip_ref.extractall(tmpdir)

                chat_texts = []
                for file in Path(tmpdir).glob("*.txt"):
                    with open(file, encoding="utf-8", errors="ignore") as f:
                        chat_texts.append(f.read())

                # full_chat is what a real summarizer would consume;
                # summarization and translation are mocked for now.
                full_chat = "\n".join(chat_texts)

                # 🧠 MOCK summarization
                summary = "[Mock summary of Hebrew WhatsApp chat...]"

                # 🌍 MOCK translation
                translated_summary = "[Translated summary in English]"

                merged_input = f"{desc_input.strip()}\n\n[Summary]: {translated_summary}"

        # Load classifier and tokenize the merged input
        tokenizer, model = load_saved_model_and_tokenizer()
        inputs = tokenizer(
            merged_input,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)

        # Multi-label inference: a sigmoid per label, not a softmax over labels
        with torch.no_grad():
            outputs = model(**inputs).logits
            probs = torch.sigmoid(outputs).cpu().numpy()

        # Static threshold values (or load from config later)
        best_low, best_high = 0.35, 0.65
        pred_soft = map_to_3_classes(probs, best_low, best_high)
        pred_str = convert_to_label_strings(pred_soft)

        return merged_input, ", ".join(pred_str)

    except Exception as e:
        return f"❌ Prediction failed: {e}", ""
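

# --- Usage sketch ---
# A minimal smoke test for calling the pipeline directly, with no chat zip.
# The `chat_zip.name` access above suggests the function is wired to a file
# upload widget (e.g. a Gradio File input); any object with a `.name`
# attribute pointing at a zip on disk would also work. This assumes the
# saved model and tokenizer exist wherever load_saved_model_and_tokenizer()
# expects them; the sample description below is purely illustrative.
if __name__ == "__main__":
    merged, labels = run_prediction_pipeline(
        "Free-text description of the incident goes here.",  # hypothetical input
        None,  # no chat export uploaded
    )
    print("Merged input:\n", merged)
    print("Predicted labels:", labels)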