"""SMS Spam Guard: a Streamlit app that analyzes a single SMS message.

The script loads two Hugging Face sequence-classification models (a language
detector and a BERT spam classifier), lets the user type or pick a sample
message, and renders the detected language, the spam verdict, timings, and a
small confidence chart.

NOTE(review): several ``st.markdown(..., unsafe_allow_html=True)`` calls below
carry empty or whitespace-only strings. It looks like the HTML/CSS markup that
used to live in them was stripped from this copy of the file -- restore it from
the original source if available.
"""

import base64
import io
import time

import pandas as pd
import streamlit as st
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
    BertTokenizerFast,
)

# Set page configuration
st.set_page_config(
    page_title="SMS Spam Guard",
    page_icon="🛡️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Display names for the language codes emitted by the language model.
# Codes missing from this table are shown as-is.
LANG_NAMES = {
    "ar": "Arabic",
    "bg": "Bulgarian",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "es": "Spanish",
    "fr": "French",
    "hi": "Hindi",
    "it": "Italian",
    "ja": "Japanese",
    "nl": "Dutch",
    "pl": "Polish",
    "pt": "Portuguese",
    "ru": "Russian",
    "sw": "Swahili",
    "th": "Thai",
    "tr": "Turkish",
    "ur": "Urdu",
    "vi": "Vietnamese",
    "zh": "Chinese",
}

# How many of the most probable languages to report.
TOP_LANGUAGES = 3


@st.cache_data(show_spinner=False)
def create_logo() -> str:
    """Draw the SafeTalk logo and return it as a base64-encoded PNG string.

    The logo is a dark-blue shield polygon on a transparent 200x200 canvas
    with the letters "ST" in white. The result is cached so the image is not
    redrawn on every Streamlit rerun.
    """
    # Create a new image with a transparent background
    img = Image.new("RGBA", (200, 200), color=(0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Draw a shield shape
    shield_color = (30, 58, 138)  # Dark blue
    points = [(100, 10), (180, 50), (160, 170), (100, 190), (40, 170), (20, 50)]
    draw.polygon(points, fill=shield_color)

    # Try to load a font, or fall back to PIL's built-in default
    try:
        font = ImageFont.truetype("arial.ttf", 80)
    except IOError:
        font = ImageFont.load_default()

    # Add "ST" text in white
    draw.text((70, 60), "ST", fill=(255, 255, 255), font=font)

    # Convert to base64 for embedding
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


# Custom CSS for styling
# NOTE(review): the stylesheet body is blank in this copy of the file.
st.markdown(""" """, unsafe_allow_html=True)


@st.cache_resource
def load_language_model():
    """Load the language detection model.

    Returns:
        A ``(tokenizer, model)`` pair for
        ``papluca/xlm-roberta-base-language-detection``.
    """
    model_name = "papluca/xlm-roberta-base-language-detection"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


@st.cache_resource
def load_spam_model():
    """Load the fine-tuned BERT spam detection model.

    Returns:
        A ``(tokenizer, model)`` pair loaded from ``chjivan/final``.
    """
    model_path = "chjivan/final"
    tokenizer = BertTokenizerFast.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path)
    return tokenizer, model


def detect_language(text, tokenizer, model):
    """Detect the language of the input text.

    Args:
        text: The message to analyze.
        tokenizer: Tokenizer matching ``model``.
        model: Sequence-classification model whose ``config.id2label`` maps
            class ids to language codes.

    Returns:
        A tuple ``(predicted_language, confidence, top_langs)`` where
        ``predicted_language`` is the label of the most probable class,
        ``confidence`` is its softmax probability, and ``top_langs`` is a list
        of ``(label, probability)`` pairs for the most probable classes
        (up to ``TOP_LANGUAGES`` of them), highest first.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # Convert logits to probabilities; [0] selects the single input sequence
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

    # Get the predicted language and its probability
    predicted_class_id = torch.argmax(probabilities).item()
    predicted_language = model.config.id2label[predicted_class_id]
    confidence = probabilities[predicted_class_id].item()

    # One topk call yields both values and indices; clamp k so a model with
    # fewer labels than TOP_LANGUAGES does not raise.
    k = min(TOP_LANGUAGES, probabilities.numel())
    top = torch.topk(probabilities, k)
    top_langs = [
        (model.config.id2label[idx], prob)
        for idx, prob in zip(top.indices.tolist(), top.values.tolist())
    ]

    return predicted_language, confidence, top_langs


def classify_spam(text, tokenizer, model):
    """Classify the input text as spam or ham.

    Args:
        text: The message to analyze.
        tokenizer: Tokenizer matching ``model``.
        model: Sequence-classification model; class id 1 is treated as spam
            and any other id as ham.

    Returns:
        A tuple ``(is_spam, confidence)`` where ``confidence`` is the softmax
        probability of the predicted class (not necessarily of the spam class).
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Convert logits to probabilities; [0] selects the single input sequence
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

    # Get the predicted class and its probability (0: ham, 1: spam)
    predicted_class_id = torch.argmax(probabilities).item()
    confidence = probabilities[predicted_class_id].item()
    is_spam = predicted_class_id == 1

    return is_spam, confidence


# Generate (cached) logo and embed it as a data URI so it can be rendered
# through st.markdown without a separate image file.
logo_base64 = create_logo()
logo_html = (
    f'<img src="data:image/png;base64,{logo_base64}" '
    'alt="SafeTalk logo" width="120">'
)

# Load both models
with st.spinner("Loading models... This may take a moment."):
    lang_tokenizer, lang_model = load_language_model()
    spam_tokenizer, spam_model = load_spam_model()

# App Header with logo
col1, col2 = st.columns([1, 5])
with col1:
    st.markdown(logo_html, unsafe_allow_html=True)
with col2:
    # NOTE(review): heading markup around these two strings appears stripped.
    st.markdown("\n\nSMS Spam Guard\n\n", unsafe_allow_html=True)
    st.markdown(
        "\n\n智能短信垃圾过滤助手 by SafeTalk Communications Ltd.\n\n",
        unsafe_allow_html=True,
    )

# Sidebar
with st.sidebar:
    st.markdown(logo_html, unsafe_allow_html=True)
    st.markdown("### About SafeTalk")
    st.markdown("SafeTalk Communications Ltd. provides intelligent communication security solutions to protect users from spam and fraudulent messages.")
    st.markdown("#### Our Technology")
    st.markdown("- ✅ Advanced AI-powered spam detection")
    st.markdown("- 🌐 Multi-language support")
    st.markdown("- 🔒 Secure and private processing")
    st.markdown("- ⚡ Real-time analysis")
    st.markdown("---")
    st.markdown("### Sample Messages")
    # Each button writes the sample into the session-state key used by the
    # text area below; this happens before that widget is created in the run,
    # so the text area picks the sample up as its current value.
    if st.button("Sample Spam (English)"):
        st.session_state.sms_input = "URGENT: You have won a $1,000 Walmart gift card. Go to http://bit.ly/claim-prize to claim now before it expires!"
    if st.button("Sample Legitimate (English)"):
        st.session_state.sms_input = "Your Amazon package will be delivered today. Thanks for ordering from Amazon!"
    if st.button("Sample Message (French)"):
        st.session_state.sms_input = "Bonjour! Votre réservation pour le restaurant est confirmée pour ce soir à 20h. À bientôt!"
    if st.button("Sample Message (Spanish)"):
        st.session_state.sms_input = "Hola, tu cita médica está programada para mañana a las 10:00. Por favor llega 15 minutos antes."

# Main Content
# NOTE(review): wrapper markup appears stripped from this string.
st.markdown("\n", unsafe_allow_html=True)

# Input form
# The widget's value is driven solely by st.session_state["sms_input"] through
# `key`. Passing `value=` as well makes Streamlit warn when a sample button has
# just set the key via the Session State API.
sms_input = st.text_area(
    "Enter the SMS message to analyze:",
    height=100,
    key="sms_input",
    help="Enter the SMS message you want to analyze for spam",
)
analyze_button = st.button("📱 Analyze Message", use_container_width=True)
# NOTE(review): wrapper markup appears stripped from this string.
st.markdown("\n", unsafe_allow_html=True)

# Process input and display results
if analyze_button and sms_input.strip():
    with st.spinner("Analyzing message..."):
        # Step 1: Language Detection
        lang_start_time = time.time()
        lang_code, lang_confidence, top_langs = detect_language(
            sms_input, lang_tokenizer, lang_model
        )
        lang_time = time.time() - lang_start_time

        # Map the detected code to a full language name
        lang_name = LANG_NAMES.get(lang_code, lang_code)

        # Step 2: Spam Classification
        spam_start_time = time.time()
        is_spam, spam_confidence = classify_spam(
            sms_input, spam_tokenizer, spam_model
        )
        spam_time = time.time() - spam_start_time

    # Display Language Detection Results
    st.markdown("### Analysis Results")
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### 📊 Language Detection")
        st.markdown("\n", unsafe_allow_html=True)
        st.markdown(
            f'{lang_name} Detected with {lang_confidence:.1%} confidence',
            unsafe_allow_html=True,
        )

        # Display top languages
        st.markdown("##### Top language probabilities:")
        for candidate_code, prob in top_langs:
            candidate_name = LANG_NAMES.get(candidate_code, candidate_code)
            st.markdown(f"- {candidate_name}: {prob:.1%}")

        st.markdown(f"⏱️ Processing time: {lang_time:.3f} seconds")
        st.markdown("\n", unsafe_allow_html=True)

    with col2:
        st.markdown("#### 🔍 Spam Detection")
        if is_spam:
            st.markdown("\n", unsafe_allow_html=True)
            st.markdown(f"⚠️ **SPAM DETECTED** with {spam_confidence:.1%} confidence")
            st.markdown("This message appears to be spam and potentially harmful.")
        else:
            st.markdown("\n", unsafe_allow_html=True)
            st.markdown(f"✅ **LEGITIMATE MESSAGE** with {spam_confidence:.1%} confidence")
            st.markdown("This message appears to be legitimate.")

        st.markdown(f"⏱️ Processing time: {spam_time:.3f} seconds")
        st.markdown("\n", unsafe_allow_html=True)

    # Summary and Recommendations
    st.markdown("### 📋 Summary & Recommendations")
    if is_spam:
        st.warning("📵 **Recommended Action**: This message should be blocked or moved to spam folder.")
        st.markdown(
            """
**Why this is likely spam:**
- Contains suspicious language patterns
- May include urgent calls to action
- Could contain unsolicited offers
"""
        )
    else:
        st.success("✅ **Recommended Action**: This message can be delivered to the inbox.")

    # Chart for visualization
    st.markdown("### 📈 Confidence Visualization")
    # NOTE(review): for a message classified as legitimate this plots
    # 1 - confidence for the "Spam Classification" bar (presumably the spam
    # probability of a two-class model) rather than the confidence shown in
    # the text above -- confirm this is intended.
    chart_data = pd.DataFrame({
        'Task': ['Language Detection', 'Spam Classification'],
        'Confidence': [
            lang_confidence,
            spam_confidence if is_spam else 1 - spam_confidence,
        ],
    })
    st.bar_chart(chart_data.set_index('Task'))
elif analyze_button:
    # Button pressed with an empty or whitespace-only message: tell the user
    # instead of silently doing nothing.
    st.warning("Please enter a message to analyze.")

# Footer
# NOTE(review): footer markup appears stripped from this string.
st.markdown('', unsafe_allow_html=True)