import zipfile
import tempfile
from pathlib import Path

import torch
from transformers import DebertaV2Tokenizer, AutoModelForSequenceClassification

from train_abuse_model import (
    MODEL_DIR,
    device,
    load_saved_model_and_tokenizer,
    map_to_3_classes,
    convert_to_label_strings,
)


def run_prediction_pipeline(desc_input, chat_zip):
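    """Classify a case description, optionally merged with an uploaded chat archive.

    Returns a (model_input_text, predicted_labels) tuple; on failure the first
    element carries the error message and the second is empty.
    """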
    try:
        merged_input = desc_input.strip()
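
        # If a chat archive was uploaded, pull every .txt transcript out of it
        # and concatenate the contents.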
        if chat_zip:
            with tempfile.TemporaryDirectory() as tmpdir:
                with zipfile.ZipFile(chat_zip.name, 'r') as zip_ref:
                    zip_ref.extractall(tmpdir)
                chat_texts = []
                for file in Path(tmpdir).glob("*.txt"):
                    with open(file, encoding="utf-8", errors="ignore") as f:
                        chat_texts.append(f.read())
                full_chat = "\n".join(chat_texts)
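
            # Placeholder steps: a full pipeline would summarize `full_chat`
            # and translate that summary into English before merging.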
            summary = "[Mock summary of Hebrew WhatsApp chat...]"
            translated_summary = "[Translated summary in English]"

            merged_input = f"{desc_input.strip()}\n\n[Summary]: {translated_summary}"
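
        # Tokenize the merged text for the saved classifier.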
        tokenizer, model = load_saved_model_and_tokenizer()
        inputs = tokenizer(
            merged_input,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)
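
        # Multi-label inference: a sigmoid over the raw logits yields an
        # independent probability per label.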
        with torch.no_grad():
            outputs = model(**inputs).logits
            probs = torch.sigmoid(outputs).cpu().numpy()
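
        # Map each probability onto one of three classes via the low/high
        # decision thresholds, then render human-readable label strings.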
        best_low, best_high = 0.35, 0.65
        pred_soft = map_to_3_classes(probs, best_low, best_high)
        pred_str = convert_to_label_strings(pred_soft)

        return merged_input, ", ".join(pred_str)

    except Exception as e:
        return f"❌ Prediction failed: {e}", ""
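

# Hypothetical smoke test, a minimal sketch rather than part of the app:
# passing chat_zip=None skips the archive branch, so only the description is
# classified. Assumes a model saved by train_abuse_model is available.
if __name__ == "__main__":
    merged, labels = run_prediction_pipeline(
        "Example case description for a quick smoke test.", None
    )
    print("Model input:\n", merged)
    print("Predicted labels:", labels)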