"""Email Attachment Classifier API.

FastAPI service that trains (or loads from disk) a TF-IDF + Multinomial
Naive Bayes pipeline and exposes endpoints predicting whether an email
message indicates an attachment (label 1) or not (label 0).
"""

import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Sequence

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# File locations (relative to the working directory) and API limits.
DATASET_PATH = "Synthetic_Email_Dataset.csv"
MODEL_PATH = "email_classifier_model.pkl"
# Metadata saved next to the model so /info can report real statistics
# after a restart instead of placeholder values.
MODEL_INFO_PATH = "email_classifier_model_info.json"
MAX_BATCH_SIZE = 100

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Email Attachment Classifier API",
    description="API to classify whether an email has attachments or not using Naive Bayes",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any site make credentialed requests. This API uses no cookies/auth, so
# consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Pydantic models
class EmailInput(BaseModel):
    """Request body for /predict."""

    message: str


class EmailBatchInput(BaseModel):
    """Request body for /predict-batch."""

    messages: List[str]


class PredictionResponse(BaseModel):
    """Prediction for one message; label 1 means "has attachment"."""

    message: str
    prediction: int
    prediction_label: str
    confidence: float
    probabilities: Dict[str, float]


class BatchPredictionResponse(BaseModel):
    """Predictions for a batch, in the same order as the request messages."""

    predictions: List[PredictionResponse]


class ModelInfo(BaseModel):
    """Summary statistics reported by /info."""

    model_type: str
    accuracy: float
    feature_count: int
    training_samples: int


# Global model state, populated at startup.
model_pipeline: Optional[Pipeline] = None
model_info: Optional[ModelInfo] = None

_SPECIAL_CHARS_RE = re.compile(r"[^\w\s,.\-!?]")
_WHITESPACE_RE = re.compile(r"\s+")


def preprocess_text(text: str) -> str:
    """Normalize an email body before vectorization.

    Lowercases the text, replaces every character that is not a word
    character, whitespace, or basic punctuation (``, . - ! ?``) with a space,
    then collapses whitespace runs to single spaces and trims the ends.
    """
    text = text.lower()
    # Replace special characters first so the spaces they leave behind are
    # collapsed by the whitespace pass that follows.
    text = _SPECIAL_CHARS_RE.sub(" ", text)
    text = _WHITESPACE_RE.sub(" ", text)
    return text.strip()


def _load_training_data() -> pd.DataFrame:
    """Return the labelled dataset, or a small synthetic sample if absent.

    The result has at least ``message`` and ``label`` columns; rows missing
    either value are dropped.
    """
    if os.path.exists(DATASET_PATH):
        df = pd.read_csv(DATASET_PATH)
    else:
        logger.warning("Dataset file not found, creating sample data")
        # Create sample data for demonstration
        sample_data = {
            "label": [0, 1, 0, 1] * 100,
            "message": [
                "Hello, You asked for it, so here is the notes. Warm wishes, David",
                "Good morning, Just sharing the meeting agenda as requested. Cheers, Anna",
                "Dear team, As discussed, I'm sending the manual. Regards, Emily",
                "Hi all, Please find attached the project plan. Thanks, Michael",
            ]
            * 100,
        }
        df = pd.DataFrame(sample_data)
    # Missing messages arrive from read_csv as NaN (a float), which would make
    # preprocess_text raise and abort training; unlabeled rows are unusable.
    return df.dropna(subset=["message", "label"])


def _build_pipeline() -> Pipeline:
    """Create the untrained TF-IDF -> Multinomial Naive Bayes pipeline."""
    return Pipeline([
        ("tfidf", TfidfVectorizer(
            max_features=1000,
            ngram_range=(1, 2),
            stop_words="english",
            lowercase=True,
            min_df=1,
            max_df=0.95,
        )),
        ("classifier", MultinomialNB(alpha=1.0)),
    ])


def _model_info_to_dict(info: ModelInfo) -> Dict[str, Any]:
    """Serialize ModelInfo to a plain JSON-compatible dict."""
    return {
        "model_type": info.model_type,
        "accuracy": float(info.accuracy),
        "feature_count": int(info.feature_count),
        "training_samples": int(info.training_samples),
    }


def _save_model(pipeline: Pipeline, info: ModelInfo) -> None:
    """Persist the pipeline and its metadata; failures are logged, not raised.

    Saving is best-effort: the freshly trained model is already in memory and
    usable even if it cannot be written to disk.
    """
    try:
        joblib.dump(pipeline, MODEL_PATH)
        with open(MODEL_INFO_PATH, "w", encoding="utf-8") as fh:
            json.dump(_model_info_to_dict(info), fh, indent=2)
        logger.info("Model saved successfully")
    except Exception as e:
        logger.exception("Error saving model: %s", e)


def load_and_train_model():
    """Train the Naive Bayes pipeline and publish it as the active model.

    Trains on 80% of the dataset, evaluates on the remaining 20%, then sets
    the module globals ``model_pipeline`` and ``model_info`` and persists
    both to disk.

    Returns:
        True if a model was trained (even if saving it to disk failed),
        False if loading data or training failed; the globals are then left
        untouched.
    """
    global model_pipeline, model_info
    try:
        df = _load_training_data()
        df["processed_message"] = df["message"].astype(str).apply(preprocess_text)

        X = df["processed_message"]
        y = df["label"]
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        pipeline = _build_pipeline()
        logger.info("Training Naive Bayes model...")
        pipeline.fit(X_train, y_train)

        y_pred = pipeline.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        info = ModelInfo(
            model_type="Multinomial Naive Bayes",
            accuracy=round(float(accuracy), 4),
            feature_count=len(pipeline.named_steps["tfidf"].vocabulary_),
            training_samples=len(X_train),
        )
        report = classification_report(y_test, y_pred)
    except Exception as e:
        logger.exception("Error in training model: %s", e)
        return False

    # Publish only a fully trained model, so /health never reports an
    # unfitted pipeline as healthy.
    model_pipeline = pipeline
    model_info = info
    logger.info(f"Model trained successfully with accuracy: {accuracy:.4f}")
    logger.info(f"Feature count: {model_info.feature_count}")
    logger.info("Classification report:\n%s", report)

    _save_model(pipeline, info)
    return True


def _load_model_info(feature_count: int) -> ModelInfo:
    """Read saved model metadata, falling back to placeholder statistics.

    ``feature_count`` comes from the loaded pipeline and always overrides the
    saved value, so it stays accurate even if the metadata file is stale.
    """
    try:
        with open(MODEL_INFO_PATH, encoding="utf-8") as fh:
            saved = json.load(fh)
        return ModelInfo(**{**saved, "feature_count": feature_count})
    except FileNotFoundError:
        pass
    except (OSError, ValueError, TypeError) as e:
        # ValueError covers malformed JSON and pydantic validation errors.
        logger.warning("Could not read model metadata from %s: %s", MODEL_INFO_PATH, e)

    # Placeholder statistics kept for backward compatibility with models
    # saved before metadata was persisted: accuracy and training_samples are
    # NOT measured for this model.
    logger.warning("Model metadata unavailable; reporting default accuracy/sample values")
    return ModelInfo(
        model_type="Multinomial Naive Bayes",
        accuracy=0.92,
        feature_count=feature_count,
        training_samples=320,
    )


def load_pretrained_model():
    """Load a previously saved pipeline from MODEL_PATH, if present.

    On success sets ``model_pipeline`` and, if not yet set, ``model_info``.

    Returns:
        True if a pipeline was loaded; False if no model file exists or it
        could not be loaded (the globals are then left untouched).
    """
    global model_pipeline, model_info
    if not os.path.exists(MODEL_PATH):
        return False
    try:
        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code. Only load model files written by this service.
        pipeline = joblib.load(MODEL_PATH)
        # Also validates the loaded object has the expected pipeline shape.
        feature_count = len(pipeline.named_steps["tfidf"].vocabulary_)
    except Exception as e:
        logger.exception("Error loading pretrained model: %s", e)
        return False

    model_pipeline = pipeline
    logger.info("Pretrained model loaded successfully")
    if model_info is None:
        model_info = _load_model_info(feature_count)
    return True


@app.on_event("startup")
async def startup_event():
    """Initialize the model: load from disk, otherwise train a new one."""
    logger.info("Starting Email Classifier API...")
    if not load_pretrained_model():
        if not load_and_train_model():
            logger.error("Failed to initialize model")


_ROOT_HTML = """<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Email Attachment Classifier API</title>
</head>
<body>
    <h1>📧 Email Attachment Classifier API</h1>
    <p>This API classifies whether an email message indicates an attachment or not using a Naive Bayes classifier.</p>

    <h2>Available Endpoints:</h2>
    <ul>
        <li>
            <strong>GET /info</strong>
            <p>Get model information and statistics</p>
        </li>
        <li>
            <strong>POST /predict</strong>
            <p>Predict single email message</p>
            <p>Body: <code>{"message": "Your email content here"}</code></p>
        </li>
        <li>
            <strong>POST /predict-batch</strong>
            <p>Predict multiple email messages</p>
            <p>Body: <code>{"messages": ["Email 1", "Email 2", ...]}</code></p>
        </li>
        <li>
            <strong>GET /health</strong>
            <p>Check API health status</p>
        </li>
    </ul>

    <h2>Interactive Documentation:</h2>
    <p>Visit <a href="/docs">/docs</a> for Swagger UI or <a href="/redoc">/redoc</a> for ReDoc</p>

    <h2>Labels:</h2>
    <ul>
        <li><code>0</code>: No attachment</li>
        <li><code>1</code>: Has attachment</li>
    </ul>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def root():
    """Root endpoint with a human-readable API overview."""
    return HTMLResponse(content=_ROOT_HTML, status_code=200)


@app.get("/health")
async def health_check():
    """Health check endpoint; reports whether a model is loaded."""
    if model_pipeline is None:
        return {"status": "unhealthy", "message": "Model not loaded"}
    return {"status": "healthy", "message": "API is running"}


@app.get("/info", response_model=ModelInfo)
async def get_model_info():
    """Get model information; 503 if no model has been initialized."""
    if model_info is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    return model_info


def _to_prediction_response(
    message: str,
    prediction: Any,
    probabilities: Sequence[float],
    classes: Sequence[Any],
) -> PredictionResponse:
    """Convert one row of model output into the API response schema.

    Args:
        message: The original (unpreprocessed) message text.
        prediction: Predicted class label for the message.
        probabilities: predict_proba row, ordered like ``classes``.
        classes: The fitted model's ``classes_``.
    """
    # Key probabilities by the model's own class order rather than assuming
    # column 0 is label 0 and column 1 is label 1; a class the model never
    # saw (e.g. a single-class training set) gets probability 0.
    by_class = {int(label): float(p) for label, p in zip(classes, probabilities)}
    prediction_label = "Has attachment" if prediction == 1 else "No attachment"
    return PredictionResponse(
        message=message,
        prediction=int(prediction),
        prediction_label=prediction_label,
        confidence=float(max(probabilities)),
        probabilities={
            "no_attachment": by_class.get(0, 0.0),
            "has_attachment": by_class.get(1, 0.0),
        },
    )


# Inference endpoints are plain `def` so FastAPI runs the CPU-bound sklearn
# work in its threadpool instead of blocking the event loop.
@app.post("/predict", response_model=PredictionResponse)
def predict_single(email: EmailInput):
    """Predict a single email message; 503 if no model, 500 on failure."""
    pipeline = model_pipeline  # snapshot: stable for this whole request
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        processed = [preprocess_text(email.message)]
        prediction = pipeline.predict(processed)[0]
        probabilities = pipeline.predict_proba(processed)[0]
        return _to_prediction_response(
            email.message, prediction, probabilities, pipeline.classes_
        )
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e


@app.post("/predict-batch", response_model=BatchPredictionResponse)
def predict_batch(emails: EmailBatchInput):
    """Predict up to MAX_BATCH_SIZE messages; results keep request order.

    Raises 503 if no model is loaded, 400 if the batch is too large, and 500
    if inference fails.
    """
    pipeline = model_pipeline  # snapshot: stable for this whole request
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if len(emails.messages) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400, detail=f"Maximum {MAX_BATCH_SIZE} messages per batch"
        )
    # sklearn estimators reject zero-sample input, so answer an empty batch
    # directly instead of surfacing that as a 500.
    if not emails.messages:
        return BatchPredictionResponse(predictions=[])

    try:
        processed = [preprocess_text(msg) for msg in emails.messages]
        batch_predictions = pipeline.predict(processed)
        batch_probabilities = pipeline.predict_proba(processed)
        predictions = [
            _to_prediction_response(msg, pred, probs, pipeline.classes_)
            for msg, pred, probs in zip(
                emails.messages, batch_predictions, batch_probabilities
            )
        ]
        return BatchPredictionResponse(predictions=predictions)
    except Exception as e:
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {str(e)}"
        ) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)