import zipfile
import tempfile
from pathlib import Path

import torch
from transformers import DebertaV2Tokenizer, AutoModelForSequenceClassification

from train_abuse_model import (
    MODEL_DIR,
    device,
    load_saved_model_and_tokenizer,
    map_to_3_classes,
    convert_to_label_strings,
)
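
# Assumed contracts for the helpers imported from train_abuse_model, noted here
# as a reading aid; the authoritative definitions live in that module:
#   load_saved_model_and_tokenizer() -> (tokenizer, model), loaded from MODEL_DIR
#       and moved to `device`
#   map_to_3_classes(probs, low, high) -> presumably buckets each sigmoid
#       probability into three ordinal classes using the two thresholds
#   convert_to_label_strings(preds) -> human-readable label strings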
def run_prediction_pipeline(desc_input, chat_zip):
    """Merge the free-text description with an optional chat export, then classify.

    Returns (merged_input, comma-separated label string); on failure, returns
    (error message, empty string).
    """
    try:
        # Start with the base input
        merged_input = desc_input.strip()

        # If a chat zip was uploaded, extract all .txt files and merge their contents
        if chat_zip:
            with tempfile.TemporaryDirectory() as tmpdir:
                with zipfile.ZipFile(chat_zip.name, "r") as zip_ref:
                    zip_ref.extractall(tmpdir)

                chat_texts = []
                for file in Path(tmpdir).glob("*.txt"):
                    with open(file, encoding="utf-8", errors="ignore") as f:
                        chat_texts.append(f.read())
                full_chat = "\n".join(chat_texts)

                # MOCK summarization (a real summarizer would consume full_chat)
                summary = "[Mock summary of Hebrew WhatsApp chat...]"

                # MOCK translation (a real translator would consume summary)
                translated_summary = "[Translated summary in English]"

                merged_input = f"{desc_input.strip()}\n\n[Summary]: {translated_summary}"

        # Load classifier and tokenize the merged input
        tokenizer, model = load_saved_model_and_tokenizer()
        inputs = tokenizer(
            merged_input,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)

        # Multi-label inference: sigmoid over the raw logits, no gradient tracking
        with torch.no_grad():
            outputs = model(**inputs).logits
        probs = torch.sigmoid(outputs).cpu().numpy()

        # Static threshold values (or load from config later)
        best_low, best_high = 0.35, 0.65
        pred_soft = map_to_3_classes(probs, best_low, best_high)
        pred_str = convert_to_label_strings(pred_soft)

        return merged_input, ", ".join(pred_str)

    except Exception as e:
        return f"Prediction failed: {e}", ""