File size: 1,944 Bytes
cf43376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import zipfile
import tempfile
from pathlib import Path
import torch
from transformers import DebertaV2Tokenizer, AutoModelForSequenceClassification

from train_abuse_model import (
    MODEL_DIR,
    device,
    load_saved_model_and_tokenizer,
    map_to_3_classes,
    convert_to_label_strings
)

def _read_chat_texts(chat_zip):
    """Return the text of every top-level ``*.txt`` file in the uploaded zip.

    Args:
        chat_zip: Either an object exposing the archive path as ``.name``
            or the archive path itself (``str`` / ``os.PathLike``).

    Returns:
        The file contents joined with newlines, in sorted filename order;
        an empty string when the archive holds no top-level ``.txt`` file.

    Raises:
        OSError, zipfile.BadZipFile: If the archive cannot be opened.
    """
    # Upload widgets differ: some hand over a temp-file wrapper, others a
    # plain path. Fall back to the value itself when there is no ``.name``.
    zip_path = getattr(chat_zip, "name", chat_zip)

    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(tmpdir)
        # Sorted so the concatenation order does not depend on the filesystem.
        return "\n".join(
            txt_path.read_text(encoding="utf-8", errors="ignore")
            for txt_path in sorted(Path(tmpdir).glob("*.txt"))
        )


def _mock_translated_summary(full_chat):
    """Return a fixed placeholder standing in for the translated chat summary.

    Summarization and translation are still mocked: ``full_chat`` is accepted
    so a real implementation can replace this helper without touching the
    caller, but it is currently ignored.
    """
    # MOCK summarization + MOCK translation
    return "[Translated summary in English]"


def run_prediction_pipeline(desc_input, chat_zip):
    """Classify a free-text description, optionally enriched by a chat export.

    Args:
        desc_input: Description text. ``None`` is treated as an empty string.
        chat_zip: Optional uploaded zip of ``.txt`` chat files (file-like
            object with ``.name`` or a path). Skipped when falsy.

    Returns:
        A 2-tuple. On success: ``(merged_input, labels)`` where
        ``merged_input`` is the text fed to the classifier and ``labels`` is
        the predicted label strings joined with ``", "``. On any failure:
        ``("❌ Prediction failed: <error>", "")``.
    """
    try:
        # Start with the base input
        description = (desc_input or "").strip()
        merged_input = description

        # If a chat zip was uploaded
        if chat_zip:
            full_chat = _read_chat_texts(chat_zip)
            translated_summary = _mock_translated_summary(full_chat)
            merged_input = f"{description}\n\n[Summary]: {translated_summary}"

        # Load classifier
        tokenizer, model = load_saved_model_and_tokenizer()
        inputs = tokenizer(
            merged_input,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs).logits
            probs = torch.sigmoid(outputs).cpu().numpy()

        # Static threshold values (or load from config later)
        best_low, best_high = 0.35, 0.65
        # NOTE(review): presumably buckets each probability into low/mid/high
        # using the two thresholds — confirm against train_abuse_model.
        pred_soft = map_to_3_classes(probs, best_low, best_high)
        pred_str = convert_to_label_strings(pred_soft)

        return merged_input, ", ".join(pred_str)

    except Exception as e:
        # UI boundary: report the failure in the first output slot rather
        # than letting the exception propagate to the caller.
        return f"❌ Prediction failed: {e}", ""