import streamlit as st
import torch
from transformers import BertForSequenceClassification, BertTokenizerFast
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import time
import pandas as pd
import base64
from PIL import Image
import io
# Browser-tab title/icon, wide layout, and an initially expanded sidebar.
_PAGE_CONFIG = dict(
    page_title="SMS Spam Guard",
    page_icon="🛡️",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Generate the SafeTalk logo (blue shield with "ST" inside) as base64 PNG.
def create_logo():
    """Render the SafeTalk shield logo and return it as a base64 PNG string.

    Draws a dark-blue shield polygon on a 200x200 transparent RGBA canvas and
    writes "ST" in white, centered on the canvas.

    Returns:
        str: ASCII base64 encoding of the PNG bytes, suitable for embedding
        in a ``data:image/png;base64,...`` URI.
    """
    from PIL import Image, ImageDraw, ImageFont
    import io
    import base64

    size = 200
    # Transparent background so the shield blends with any page colour.
    img = Image.new('RGBA', (size, size), color=(0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    shield_color = (30, 58, 138)  # Dark blue
    points = [(100, 10), (180, 50), (160, 170), (100, 190), (40, 170), (20, 50)]
    draw.polygon(points, fill=shield_color)

    # "arial.ttf" usually only resolves on Windows; try common Linux fonts
    # before falling back to Pillow's small built-in default font.
    font = None
    for font_name in ("arial.ttf", "DejaVuSans-Bold.ttf", "DejaVuSans.ttf"):
        try:
            font = ImageFont.truetype(font_name, 80)
            break
        except OSError:
            continue
    if font is None:
        font = ImageFont.load_default()

    text = "ST"
    position = (70, 60)  # fallback offset used by the original layout
    try:
        # Center the text's bounding box on the canvas.
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        position = (
            (size - (right - left)) // 2 - left,
            (size - (bottom - top)) // 2 - top,
        )
    except (AttributeError, ValueError):
        # Older Pillow lacks textbbox (AttributeError) or only supports it for
        # TrueType fonts (ValueError); keep the fixed offset in that case.
        pass
    draw.text(position, text, fill=(255, 255, 255), font=font)

    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()
# Custom CSS for styling.
# NOTE(review): the string passed below contains only a newline, so no CSS is
# actually injected; the style rules appear to be missing — TODO restore them
# (e.g. inside a <style> block) or drop this call.
st.markdown("""
""", unsafe_allow_html=True)
@st.cache_resource
def load_language_model():
    """Return the ``(tokenizer, model)`` pair used for language identification.

    Both objects are loaded via ``from_pretrained`` from the Hugging Face
    checkpoint ``papluca/xlm-roberta-base-language-detection``. The
    ``st.cache_resource`` decorator keeps the loaded pair for reuse on later
    reruns instead of reloading it each time.
    """
    checkpoint = "papluca/xlm-roberta-base-language-detection"
    # Tuple elements evaluate left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSequenceClassification.from_pretrained(checkpoint),
    )
@st.cache_resource
def load_spam_model():
    """Return the ``(tokenizer, model)`` pair used for spam/ham classification.

    Loads the fine-tuned BERT checkpoint ``chjivan/final`` with
    ``BertTokenizerFast`` and ``BertForSequenceClassification``. The
    ``st.cache_resource`` decorator keeps the loaded pair for reuse on later
    reruns instead of reloading it each time.
    """
    checkpoint = "chjivan/final"
    # Tuple elements evaluate left to right: tokenizer first, then model.
    return (
        BertTokenizerFast.from_pretrained(checkpoint),
        BertForSequenceClassification.from_pretrained(checkpoint),
    )
def detect_language(text, tokenizer, model, top_k=3):
    """Detect the language of ``text``.

    Args:
        text: Input string to classify.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) that
            returns a mapping of PyTorch tensors when called with
            ``return_tensors="pt"``.
        model: Sequence-classification model whose output exposes ``.logits``
            and whose ``config.id2label`` maps class ids to language labels.
        top_k: Number of most-probable languages to return (default 3).
            Clamped to the number of classes the model produces, so models
            with fewer labels do not raise.

    Returns:
        tuple: ``(predicted_language, confidence, top_langs)`` where
        ``confidence`` is the softmax probability of the predicted label and
        ``top_langs`` is a list of ``(label, probability)`` pairs in
        descending probability order.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over classes for the first (only) item in the batch.
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

    # argmax returns the first maximal index on ties, matching prior behavior.
    predicted_class_id = torch.argmax(probabilities).item()
    predicted_language = model.config.id2label[predicted_class_id]
    confidence = probabilities[predicted_class_id].item()

    # Single topk call; k clamped to [0, num_classes] to avoid a RuntimeError.
    k = max(0, min(top_k, probabilities.numel()))
    top = torch.topk(probabilities, k)
    top_langs = [
        (model.config.id2label[idx], prob)
        for idx, prob in zip(top.indices.tolist(), top.values.tolist())
    ]
    return predicted_language, confidence, top_langs
def classify_spam(text, tokenizer, model, max_length=128):
    """Classify ``text`` as spam or ham.

    Args:
        text: The SMS message to classify.
        tokenizer: Callable tokenizer returning PyTorch tensors when called
            with ``return_tensors="pt"``.
        model: Binary sequence-classification model whose output exposes
            ``.logits``; class 0 is ham and class 1 is spam.
        max_length: Maximum token length; longer inputs are truncated
            (default 128, the previous hard-coded value).

    Returns:
        tuple: ``(is_spam, confidence)`` where ``is_spam`` is True when the
        predicted class id is 1 and ``confidence`` is the softmax probability
        of the predicted class.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over classes for the first (only) item in the batch.
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

    predicted_class_id = torch.argmax(probabilities).item()
    confidence = probabilities[predicted_class_id].item()
    is_spam = predicted_class_id == 1  # 0: ham, 1: spam
    return is_spam, confidence
# Generate the logo on each script run and wrap it in an <img> tag so it can be
# rendered with st.markdown(..., unsafe_allow_html=True).
logo_base64 = create_logo()
# NOTE(review): this f-string was previously empty, so the computed logo was
# never displayed; the data-URI <img> below restores it. Adjust width as needed.
logo_html = (
    f'<img src="data:image/png;base64,{logo_base64}" '
    f'width="100" alt="SafeTalk logo">'
)

# Load both models; the loaders are wrapped in st.cache_resource, so later
# reruns reuse the already-loaded objects.
with st.spinner("Loading models... This may take a moment."):
    lang_tokenizer, lang_model = load_language_model()
    spam_tokenizer, spam_model = load_spam_model()
# App header: logo in a narrow left column, tagline in the wide right column.
col1, col2 = st.columns([1, 5])
with col1:
    st.markdown(logo_html, unsafe_allow_html=True)
with col2:
    # NOTE(review): this literal previously spanned several physical lines
    # inside a single-quoted string (a SyntaxError); the newlines are now
    # explicit escapes. Any HTML markup around the text appears to have been
    # lost — restore it if the header should be styled.
    st.markdown(
        '\n智能短信垃圾过滤助手 by SafeTalk Communications Ltd.\n',
        unsafe_allow_html=True,
    )

# Sidebar: company blurb plus buttons that pre-fill the message input.
# (Previously this whole section sat behind a '#' on one collapsed line and
# never executed.)
with st.sidebar:
    st.markdown(logo_html, unsafe_allow_html=True)
    st.markdown("### About SafeTalk")
    st.markdown("SafeTalk Communications Ltd. provides intelligent communication security solutions to protect users from spam and fraudulent messages.")
    st.markdown("#### Our Technology")
    st.markdown("- ✅ Advanced AI-powered spam detection")
    st.markdown("- 🌐 Multi-language support")
    st.markdown("- 🔒 Secure and private processing")
    st.markdown("- ⚡ Real-time analysis")
    st.markdown("---")
    st.markdown("### Sample Messages")

    # Button label -> message written into st.session_state.sms_input.
    # Presumably the main input widget uses key="sms_input" (defined past
    # this chunk) — verify.
    sample_messages = {
        "Sample Spam (English)": "URGENT: You have won a $1,000 Walmart gift card. Go to http://bit.ly/claim-prize to claim now before it expires!",
        "Sample Legitimate (English)": "Your Amazon package will be delivered today. Thanks for ordering from Amazon!",
        "Sample Message (French)": "Bonjour! Votre réservation pour le restaurant est confirmée pour ce soir à 20h. À bientôt!",
        "Sample Message (Spanish)": "Hola, tu cita médica está programada para mañana a las 10:00. Por favor llega 15 minutos antes.",
    }
    for label, message in sample_messages.items():
        if st.button(label):
            st.session_state.sms_input = message

# Main Content
# NOTE(review): the original line ended with the start of a `st.markdown('`
# call whose text continues beyond this chunk. That fragment was inside a
# comment there, so it is kept commented here to leave the following lines
# parsing exactly as before: st.markdown('