|
import os

# Set before `transformers` is imported so its model cache lives under /app/.cache.
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'

import base64
import json
import threading
import time
import uuid

import requests
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_sock import Sock
from transformers import pipeline
|
|
|
|
|
app = Flask(__name__)
# Enable CORS so browser clients served from another origin can call the HTTP API.
CORS(app)
# WebSocket support (used by the /ws/<session_id> route).
sock = Sock(app)

# Zero-shot classifier used to screen decrypted messages; loaded once at import time.
classifier = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-1")

# In-memory session store keyed by session id; contents are lost on restart.
# Accessed from request/WebSocket handlers and from cleanup timer threads.
SESSIONS = {}

# Candidate labels every decrypted message is scored against.
SENSITIVE_LABELS = ["terrorism", "blackmail", "national security threat"]

# Endpoint that receives flagged sessions. Defaults to the original hard-coded
# URL; set the STORAGE_API environment variable to point elsewhere.
STORAGE_API = os.environ.get("STORAGE_API", "https://mike23415-storage.hf.space/api/flag")
|
|
|
def decrypt_message(encrypted_b64, password, tag_len=16):
    """Decrypt a base64 AES-256-GCM payload using a key derived from *password*.

    The key is the SHA-256 digest of the UTF-8 encoded password. The decoded
    payload is split as ``nonce (first 12 bytes) || ciphertext || tag`` where
    the trailing ``tag_len`` bytes are the GCM authentication tag, which is
    verified before any plaintext is returned.

    NOTE(review): this assumes the client appends the tag to the ciphertext,
    as WebCrypto's AES-GCM ``encrypt`` does; confirm against the client.
    NOTE(review): a single unsalted SHA-256 is a weak password KDF; changing it
    needs a matching client change.

    Args:
        encrypted_b64: base64 text of the payload.
        password: chat password the AES key is derived from.
        tag_len: length in bytes of the trailing GCM tag (PyCryptodome accepts
            4-16). ``0`` skips authentication and decrypts everything after
            the nonce, which is unsafe and only kept for legacy clients.

    Returns:
        The plaintext as a ``str``, or ``None`` if base64 decoding, tag
        verification, decryption, or UTF-8 decoding fails (the reason is printed).
    """
    try:
        key = SHA256.new(password.encode()).digest()
        raw = base64.b64decode(encrypted_b64)
        nonce, body = raw[:12], raw[12:]

        if tag_len:
            if len(body) < tag_len:
                raise ValueError("payload is shorter than the GCM tag")
            ciphertext, tag = body[:-tag_len], body[-tag_len:]
            cipher = AES.new(key, AES.MODE_GCM, nonce=nonce, mac_len=tag_len)
            # Raises ValueError if the tag does not match (tampered/wrong key).
            plaintext = cipher.decrypt_and_verify(ciphertext, tag)
        else:
            cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
            plaintext = cipher.decrypt(body)

        return plaintext.decode()
    except Exception as e:
        # Best effort: callers treat None as "could not read this message".
        print(f"Decryption failed: {e}")
        return None
|
|
|
def flag_if_sensitive(decrypted_text, ip, session_id, role, encrypted_msg,
                      threshold=0.8, multi_label=True):
    """Score a decrypted message against SENSITIVE_LABELS and record a hit.

    If the best-scoring label's score exceeds *threshold*, the session is
    marked ``flagged`` and an entry (encrypted and decrypted text, label,
    score, role, ip, timestamp) is appended to its ``flagged_messages``.

    Args:
        decrypted_text: plaintext to classify; empty/None is ignored.
        ip: client address stored with the flagged entry.
        session_id: key into SESSIONS.
        role: sender role stored with the flagged entry.
        encrypted_msg: original ciphertext stored with the flagged entry.
        threshold: a label must score strictly above this to flag.
        multi_label: scoring mode passed to the zero-shot pipeline. With
            ``True`` each label is scored independently (entailment vs.
            contradiction). With ``False`` scores are normalized to sum to 1
            across only these sensitive labels, so benign text is still pushed
            onto one of them.

    Returns:
        None.
    """
    if not decrypted_text:
        return

    result = classifier(decrypted_text, SENSITIVE_LABELS, multi_label=multi_label)
    label, score = max(zip(result["labels"], result["scores"]), key=lambda pair: pair[1])
    if score <= threshold:
        return

    print(f"⚠️ FLAGGED: {label} with score {score}")

    # The session may have been removed by cleanup_session while the model ran.
    session = SESSIONS.get(session_id)
    if session is None:
        return

    session["flagged"] = True
    session["flagged_messages"].append({
        "encrypted_msg": encrypted_msg,
        "decrypted_msg": decrypted_text,
        "label": label,
        "score": score,
        "role": role,
        "ip": ip,
        "timestamp": time.time(),
    })
|
|
|
def log_flagged_session(session_id):
    """POST a flagged session's history to STORAGE_API (best effort).

    Does nothing when the session does not exist or has not been flagged.
    The payload contains the session id, creation time, all messages, the
    distinct client IPs (in order of first appearance), and the flagged entries.
    Failures are printed, never raised, so session cleanup still proceeds.
    """
    session = SESSIONS.get(session_id)
    if session is None or not session["flagged"]:
        return

    # Snapshot the lists so the payload is built from one consistent view.
    messages = list(session["messages"])
    payload = {
        "session_id": session_id,
        "created_at": session["created_at"],
        "messages": messages,
        # dict.fromkeys dedupes while keeping first-seen order (set() order is arbitrary).
        "unique_ips": list(dict.fromkeys(msg["ip"] for msg in messages)),
        "flagged_messages": list(session["flagged_messages"]),
    }

    try:
        response = requests.post(STORAGE_API, json=payload, timeout=3)
        # Treat 4xx/5xx from the storage service as a failure, not a success.
        response.raise_for_status()
        print(f"Logged flagged session {session_id}")
    except Exception as e:
        # Broad on purpose: this runs from a timer thread right before the
        # session is deleted, and must not abort that cleanup.
        print(f"Failed to log session {session_id}: {e}")
|
|
|
def cleanup_session(session_id):
    """Expire a session: ship it to storage if flagged, then forget it.

    Unknown session ids are ignored.
    """
    if session_id not in SESSIONS:
        return

    log_flagged_session(session_id)
    SESSIONS.pop(session_id)
    print(f"Deleted session {session_id}")
|
|
|
@app.route("/api/create_chat", methods=["POST"])
def create_chat():
    """Create a chat session and return its id and a short share URL.

    Request body (optional JSON object): ``{"password": <str>}``. The password
    defaults to "default" when absent. A missing, non-JSON, or non-object
    body is treated as ``{}`` rather than crashing the handler.

    The session is deleted by cleanup_session after 900 seconds.

    Returns:
        JSON ``{"session_id": ..., "short_id": ..., "short_url": ...}`` where
        ``short_id`` is the first 8 characters of the session id.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    password = data.get("password", "default")

    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "password": password,
        "created_at": time.time(),
        "messages": [],
        "flagged": False,
        "flagged_messages": [],
        "connections": [],
    }

    ttl_seconds = 900
    threading.Timer(ttl_seconds, cleanup_session, args=[session_id]).start()

    short_id = session_id[:8]
    short_url = f"https://{request.host}/s/{short_id}"
    return jsonify({"session_id": session_id, "short_id": short_id, "short_url": short_url})
|
|
|
def _next_role(session):
    """Return the display role for a socket joining *session* and count the join.

    The first socket to join is "Sender"; later ones are "Receiver 1",
    "Receiver 2", ... in join order. The counter is kept in
    session["join_count"] and created on first use.
    """
    joined = session.get("join_count", 0)
    session["join_count"] = joined + 1
    return "Sender" if joined == 0 else f"Receiver {joined}"


def _broadcast(session, payload):
    """Send the text frame *payload* to every socket joined to *session*."""
    # Iterate over a snapshot: other handler threads remove their own socket
    # from this list when they disconnect.
    for conn in list(session["connections"]):
        try:
            conn.send(payload)
        except Exception:
            # One dead peer must not block delivery to the rest; its own
            # handler removes it from the list when it exits.
            continue


@sock.route('/ws/<session_id>')
def chat(ws, session_id):
    """WebSocket handler that relays a session's encrypted messages to all peers.

    Each received frame is appended to the session's message history. It is
    then broadcast to every connected socket, including the sender, as
    ``{"role": ..., "encrypted_msg": ...}``. Finally it is decrypted with the
    session password and passed to flag_if_sensitive.

    The loop ends when the peer disconnects, ``receive()`` returns None, or
    the session has been removed from SESSIONS (e.g. by cleanup_session).
    """
    ip = request.remote_addr or "unknown"
    session = SESSIONS.get(session_id)
    if session is None:
        ws.send('{"type": "error", "message": "Session not found"}')
        ws.close()
        return

    # Roles follow join order rather than message history, so peers that join
    # before anyone has spoken still get distinct roles.
    role = _next_role(session)
    session["connections"].append(ws)

    try:
        while True:
            msg = ws.receive()
            if msg is None:
                break
            # The session expired while this socket was open; stop relaying.
            if SESSIONS.get(session_id) is not session:
                break
            if isinstance(msg, bytes):
                # Binary frames are treated as UTF-8 text so they serialize as JSON.
                msg = msg.decode("utf-8", errors="replace")

            session["messages"].append({
                "role": role,
                "encrypted_msg": msg,
                "ip": ip,
                "timestamp": time.time(),
            })

            # json.dumps escapes quotes/backslashes in the client-supplied text,
            # so a message cannot break the frame or inject extra fields.
            _broadcast(session, json.dumps({"role": role, "encrypted_msg": msg}))

            # Screen after relaying so model inference does not delay delivery,
            # and so a screening failure does not drop the connection.
            try:
                decrypted_text = decrypt_message(msg, session["password"])
                flag_if_sensitive(decrypted_text, ip, session_id, role, msg)
            except Exception as e:
                print(f"Message screening failed in session {session_id}: {e}")
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Use the local reference: SESSIONS[session_id] may already be deleted.
        if ws in session["connections"]:
            session["connections"].remove(ws)
|
|
|
@app.route("/")
def root():
    """Liveness check: reply with a fixed plain-text banner."""
    banner = "Real-time AI chat backend is running."
    return banner
|
|
|
if __name__ == "__main__":
    # Run the built-in server on all interfaces. The port defaults to 7860 and
    # can be overridden with the PORT environment variable.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))