from time import sleep
import logging
import sys
import re
import httpx
from fastapi import FastAPI
from fastapi.responses import JSONResponse, FileResponse
from transformers import pipeline
from phishing_datasets import submit_entry
from url_tools import extract_urls, resolve_short_url, extract_domain_from_url
from urlscan_client import UrlscanClient
import requests
from mnemonic_attack import find_confusable_brand
from models.models import MessageModel, QueryModel, AppModel, InputModel, OutputModel, ReportModel
from models.enums import ActionModel, SubActionModel

app = FastAPI()
urlscan = UrlscanClient()

# Remove all handlers associated with the root logger object; basicConfig()
# does nothing when the root logger already has handlers attached.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(asctime)s %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)

pipe = pipeline(task="text-classification", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")

# --- Debug hooks -------------------------------------------------------------
# A message whose entire text is exactly one of these strings bypasses the
# classifier, letting a tester force a delay or a given category.
DEBUG_SLEEP_PATTERN = re.compile(r"^Sent from your Twilio trial account - sleep (\d+)$")
DEBUG_CATEGORY_PATTERN = re.compile(r"^Sent from your Twilio trial account - (junk|transaction|promotion)$")
# The message text comes from arbitrary SMS senders, so the requested delay is
# capped to keep one message from blocking a request handler indefinitely.
MAX_DEBUG_SLEEP_SECONDS = 60
_DEBUG_CATEGORY_ACTIONS = {
    'junk': ActionModel.JUNK,
    'transaction': ActionModel.TRANSACTION,
    'promotion': ActionModel.PROMOTION,
}

# --- Sender classification ---------------------------------------------------
# Alphanumeric SenderID: a letter followed by 2-14 letters, digits, '-' or '.'.
ALPHANUMERIC_SENDER_PATTERN = re.compile(r'^[A-Za-z][A-Za-z0-9\-\.]{2,14}$')
# Short code: five digits whose first digit is 3-8.
SHORT_CODE_SENDER_PATTERN = re.compile(r'^(?:3\d{4}|[4-8]\d{4})$')

# --- Scoring -----------------------------------------------------------------
COMMERCIAL_SENDER_FACTOR = 0.7  # dampening for short-code / alphanumeric senders
BENIGN_FACTOR = 0.9             # dampening when no URL, or a URL urlscan rates benign
JUNK_THRESHOLD = 0.7            # above this: always JUNK
SUSPICIOUS_THRESHOLD = 0.5      # above this: PROMOTION for commercial senders, else JUNK


@app.get("/.well-known/apple-app-site-association", include_in_schema=False)
def get_well_known_aasa():
    """Serve the Apple App Site Association document listing the messagefilter apps."""
    return JSONResponse(
        content={
            "messagefilter": {
                "apps": [
                    "X9NN3FSS3T.com.lela.Serenity.SerenityMessageFilterExtension",
                    "X9NN3FSS3T.com.lela.Serenity"
                ]
            }
        },
        media_type="application/json"
    )


@app.get("/robots.txt", include_in_schema=False)
def get_robots_txt():
    """Serve the robots.txt file (path relative to the process working directory)."""
    return FileResponse("robots.txt")


def _output(action: ActionModel) -> OutputModel:
    """Build a response carrying *action* and no sub-action."""
    return OutputModel(action=action, sub_action=SubActionModel.NONE)


def _debug_override(sender: str, text: str) -> OutputModel | None:
    """Return a forced response for the Twilio-trial debug messages, or None."""
    sleep_match = DEBUG_SLEEP_PATTERN.search(text)
    if sleep_match:
        requested = int(sleep_match.group(1))
        sleep_duration = min(requested, MAX_DEBUG_SLEEP_SECONDS)
        if sleep_duration < requested:
            logging.warning(f"[DEBUG SLEEP] Requested {requested}s capped to {sleep_duration}s")
        # Logged at INFO: the root logger is configured at INFO, so a DEBUG
        # record here would never be emitted.
        logging.info(f"[DEBUG SLEEP] Sleeping for {sleep_duration} seconds for sender {sender}")
        sleep(sleep_duration)
        return _output(ActionModel.JUNK)

    category_match = DEBUG_CATEGORY_PATTERN.search(text)
    if category_match:
        category_str = category_match.group(1)
        logging.info(f"[DEBUG CATEGORY] Forced category: {category_str} for sender {sender}")
        # The regex only admits the three keys of _DEBUG_CATEGORY_ACTIONS.
        return _output(_DEBUG_CATEGORY_ACTIONS[category_str])

    return None


def _classifier_spam_score(text: str) -> float:
    """Run the text classifier and return a spam score (higher = more likely spam)."""
    prediction = pipe(text)[0]
    label = prediction['label']
    score = prediction['score']
    logging.info(f"[CLASSIFICATION] label={label} score={score}")
    # The pipeline reports the confidence of the predicted label; LABEL_0 is
    # treated as the ham class, so its confidence is flipped into a spam score.
    if label == 'LABEL_0':
        score = 1 - score
    return score


def _is_commercial_sender(sender: str) -> bool:
    """Return True (and log why) if *sender* is a short code or an alphanumeric SenderID."""
    if SHORT_CODE_SENDER_PATTERN.search(sender):
        logging.info("[COMMERCIAL] Commercial sender detected (short code)")
        return True
    if ALPHANUMERIC_SENDER_PATTERN.match(sender):
        logging.info("[COMMERCIAL] Alphanumeric SenderID detected")
        return True
    return False


def _collect_scan_results(urls) -> list:
    """Fetch previous urlscan results for the URLs' domains; scan the URLs if none exist."""
    logging.info("[URL] Searching for previous scans")
    search_results = [urlscan.search(f"domain:{extract_domain_from_url(url)}") for url in urls]
    scan_results = []
    for search_result in search_results:
        for hit in search_result.get('results', []):
            result_uuid = hit.get('_id')
            if not result_uuid:
                # Without an id there is nothing to fetch; previously the `str`
                # type object was passed to get_result() here.
                continue
            scan_results.append(urlscan.get_result(result_uuid))
    if not scan_results:
        logging.info("[URL] No previous scan found, launching a new scan...")
        scan_results = [urlscan.scan(url) for url in urls]
    return scan_results


def _apply_urlscan_verdicts(score: float, scan_results: list) -> float:
    """Adjust the classifier *score* with urlscan overall verdicts.

    A positive verdict score forces 1.0 and stops. Each negative verdict score
    multiplies *score* by BENIGN_FACTOR. Results without verdicts, or with a
    zero or missing verdict score, leave *score* unchanged. The verdict score
    is kept separate so it never overwrites the classifier score.
    """
    for scan_result in scan_results:
        overall = scan_result.get('verdicts', {}).get('overall', {})
        logging.info(f"[URLSCAN] Overall verdict: {overall}")
        if not overall.get('hasVerdicts'):
            continue
        verdict_score = overall.get('score')
        logging.info(f"[URLSCAN] Verdict score: {verdict_score}")
        if verdict_score is None:
            continue
        if verdict_score > 0:
            return 1.0
        if verdict_score < 0:
            score = score * BENIGN_FACTOR
    return score


def _decide_action(score: float, commercial_sender: bool) -> ActionModel:
    """Map the final score to a filter action."""
    if score > JUNK_THRESHOLD:
        return ActionModel.JUNK
    if score > SUSPICIOUS_THRESHOLD:
        return ActionModel.PROMOTION if commercial_sender else ActionModel.JUNK
    return ActionModel.NONE


@app.post("/predict")
def predict(model: InputModel) -> OutputModel:
    """Classify an incoming SMS and return the filter action for it.

    Order: debug overrides, brand-usurpation check, text classifier,
    commercial-sender and URL-reputation adjustments, threshold decision.
    """
    sender = model.query.sender
    text = model.query.message.text
    logging.info(f"[{sender}] {text}")

    forced = _debug_override(sender, text)
    if forced is not None:
        return forced

    # Brand usurpation detection using confusables
    confusable_brand = find_confusable_brand(text)
    if confusable_brand:
        logging.warning(f"[BRAND USURPATION] Confusable/homoglyph variant of brand '{confusable_brand}' detected in message. Classified as JUNK.")
        return _output(ActionModel.JUNK)

    score = _classifier_spam_score(text)

    # Commercial senders are dampened and, in the grey zone, become PROMOTION
    # instead of JUNK (the flag was previously never set, making that branch dead).
    commercial_sender = _is_commercial_sender(sender)
    if commercial_sender:
        score = score * COMMERCIAL_SENDER_FACTOR

    urls = extract_urls(text)
    if urls:
        logging.info(f"[URL] URLs found: {urls}")
        score = _apply_urlscan_verdicts(score, _collect_scan_results(urls))
    else:
        logging.info("[URL] No URL found")
        score = score * BENIGN_FACTOR

    logging.info(f"[FINAL SCORE] {score}")
    action = _decide_action(score, commercial_sender)
    logging.info(f"[FINAL ACTION] {action}")
    return _output(action)


@app.post("/report")
def report(model: ReportModel):
    """Submit a user-reported sender/message pair via phishing_datasets.submit_entry."""
    submit_entry(model.sender, model.message)