Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import HTMLResponse | |
from pydantic import BaseModel | |
import joblib | |
import pandas as pd | |
import numpy as np | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.naive_bayes import MultinomialNB | |
from sklearn.pipeline import Pipeline | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import classification_report, accuracy_score | |
import re | |
import os | |
from typing import List, Dict, Any | |
import logging | |
# Configure logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# Initialize FastAPI app | |
app = FastAPI( | |
title="Email Attachment Classifier API", | |
description="API to classify whether an email has attachments or not using Naive Bayes", | |
version="1.0.0", | |
docs_url="/docs", | |
redoc_url="/redoc" | |
) | |
# Add CORS middleware | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# Pydantic models | |
class EmailInput(BaseModel): | |
message: str | |
class EmailBatchInput(BaseModel): | |
messages: List[str] | |
class PredictionResponse(BaseModel): | |
message: str | |
prediction: int | |
prediction_label: str | |
confidence: float | |
probabilities: Dict[str, float] | |
class BatchPredictionResponse(BaseModel): | |
predictions: List[PredictionResponse] | |
class ModelInfo(BaseModel): | |
model_type: str | |
accuracy: float | |
feature_count: int | |
training_samples: int | |
# Global variables | |
model_pipeline = None | |
model_info = None | |
def preprocess_text(text: str) -> str: | |
"""Preprocess email text""" | |
# Convert to lowercase | |
text = text.lower() | |
# Remove extra whitespace | |
text = re.sub(r'\s+', ' ', text) | |
# Remove special characters but keep basic punctuation | |
text = re.sub(r'[^\w\s,.\-!?]', ' ', text) | |
return text.strip() | |
def load_and_train_model(): | |
"""Load data and train the Naive Bayes model""" | |
global model_pipeline, model_info | |
try: | |
# Load the dataset (assuming it's in the same directory) | |
if os.path.exists('Synthetic_Email_Dataset.csv'): | |
df = pd.read_csv('Synthetic_Email_Dataset.csv') | |
else: | |
logger.warning("Dataset file not found, creating sample data") | |
# Create sample data for demonstration | |
sample_data = { | |
'label': [0, 1, 0, 1] * 100, | |
'message': [ | |
"Hello, You asked for it, so here is the notes. Warm wishes, David", | |
"Good morning, Just sharing the meeting agenda as requested. Cheers, Anna", | |
"Dear team, As discussed, I'm sending the manual. Regards, Emily", | |
"Hi all, Please find attached the project plan. Thanks, Michael" | |
] * 100 | |
} | |
df = pd.DataFrame(sample_data) | |
# Preprocess messages | |
df['processed_message'] = df['message'].apply(preprocess_text) | |
# Split data | |
X = df['processed_message'] | |
y = df['label'] | |
X_train, X_test, y_train, y_test = train_test_split( | |
X, y, test_size=0.2, random_state=42, stratify=y | |
) | |
# Create pipeline | |
model_pipeline = Pipeline([ | |
('tfidf', TfidfVectorizer( | |
max_features=1000, | |
ngram_range=(1, 2), | |
stop_words='english', | |
lowercase=True, | |
min_df=1, | |
max_df=0.95 | |
)), | |
('classifier', MultinomialNB(alpha=1.0)) | |
]) | |
# Train model | |
logger.info("Training Naive Bayes model...") | |
model_pipeline.fit(X_train, y_train) | |
# Evaluate model | |
y_pred = model_pipeline.predict(X_test) | |
accuracy = accuracy_score(y_test, y_pred) | |
# Store model info | |
model_info = ModelInfo( | |
model_type="Multinomial Naive Bayes", | |
accuracy=round(accuracy, 4), | |
feature_count=model_pipeline.named_steps['tfidf'].vocabulary_.__len__(), | |
training_samples=len(X_train) | |
) | |
logger.info(f"Model trained successfully with accuracy: {accuracy:.4f}") | |
logger.info(f"Feature count: {model_info.feature_count}") | |
# Save model | |
joblib.dump(model_pipeline, 'email_classifier_model.pkl') | |
logger.info("Model saved successfully") | |
return True | |
except Exception as e: | |
logger.error(f"Error in training model: {str(e)}") | |
return False | |
def load_pretrained_model(): | |
"""Load pretrained model if available""" | |
global model_pipeline, model_info | |
try: | |
if os.path.exists('email_classifier_model.pkl'): | |
model_pipeline = joblib.load('email_classifier_model.pkl') | |
logger.info("Pretrained model loaded successfully") | |
# Set default model info if not available | |
if model_info is None: | |
model_info = ModelInfo( | |
model_type="Multinomial Naive Bayes", | |
accuracy=0.92, # Default value | |
feature_count=len(model_pipeline.named_steps['tfidf'].vocabulary_), | |
training_samples=320 # Default value | |
) | |
return True | |
except Exception as e: | |
logger.error(f"Error loading pretrained model: {str(e)}") | |
return False | |
async def startup_event(): | |
"""Initialize model on startup""" | |
logger.info("Starting Email Classifier API...") | |
# Try to load pretrained model first | |
if not load_pretrained_model(): | |
# If no pretrained model, train new one | |
if not load_and_train_model(): | |
logger.error("Failed to initialize model") | |
async def root(): | |
"""Root endpoint with API documentation""" | |
html_content = """ | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>Email Attachment Classifier API</title> | |
<style> | |
body { font-family: Arial, sans-serif; margin: 40px; } | |
.header { color: #2c3e50; } | |
.endpoint { background-color: #f8f9fa; padding: 15px; margin: 10px 0; border-radius: 5px; } | |
.method { color: #27ae60; font-weight: bold; } | |
code { background-color: #e9ecef; padding: 2px 4px; border-radius: 3px; } | |
</style> | |
</head> | |
<body> | |
<h1 class="header">📧 Email Attachment Classifier API</h1> | |
<p>This API classifies whether an email message indicates an attachment or not using Naive Bayes classifier.</p> | |
<h2>Available Endpoints:</h2> | |
<div class="endpoint"> | |
<h3><span class="method">GET</span> /info</h3> | |
<p>Get model information and statistics</p> | |
</div> | |
<div class="endpoint"> | |
<h3><span class="method">POST</span> /predict</h3> | |
<p>Predict single email message</p> | |
<p><strong>Body:</strong> <code>{"message": "Your email content here"}</code></p> | |
</div> | |
<div class="endpoint"> | |
<h3><span class="method">POST</span> /predict-batch</h3> | |
<p>Predict multiple email messages</p> | |
<p><strong>Body:</strong> <code>{"messages": ["Email 1", "Email 2", ...]}</code></p> | |
</div> | |
<div class="endpoint"> | |
<h3><span class="method">GET</span> /health</h3> | |
<p>Check API health status</p> | |
</div> | |
<h2>Interactive Documentation:</h2> | |
<p>Visit <a href="/docs">/docs</a> for Swagger UI or <a href="/redoc">/redoc</a> for ReDoc</p> | |
<h2>Labels:</h2> | |
<ul> | |
<li><strong>0:</strong> No attachment mentioned</li> | |
<li><strong>1:</strong> Attachment mentioned</li> | |
</ul> | |
</body> | |
</html> | |
""" | |
return HTMLResponse(content=html_content, status_code=200) | |
async def health_check(): | |
"""Health check endpoint""" | |
if model_pipeline is None: | |
return {"status": "unhealthy", "message": "Model not loaded"} | |
return {"status": "healthy", "message": "API is running"} | |
async def get_model_info(): | |
"""Get model information""" | |
if model_info is None: | |
raise HTTPException(status_code=503, detail="Model not initialized") | |
return model_info | |
async def predict_single(email: EmailInput): | |
"""Predict single email message""" | |
if model_pipeline is None: | |
raise HTTPException(status_code=503, detail="Model not loaded") | |
try: | |
# Preprocess input | |
processed_message = preprocess_text(email.message) | |
# Make prediction | |
prediction = model_pipeline.predict([processed_message])[0] | |
probabilities = model_pipeline.predict_proba([processed_message])[0] | |
# Prepare response | |
prediction_label = "Has attachment" if prediction == 1 else "No attachment" | |
confidence = float(max(probabilities)) | |
prob_dict = { | |
"no_attachment": float(probabilities[0]), | |
"has_attachment": float(probabilities[1]) | |
} | |
return PredictionResponse( | |
message=email.message, | |
prediction=int(prediction), | |
prediction_label=prediction_label, | |
confidence=confidence, | |
probabilities=prob_dict | |
) | |
except Exception as e: | |
logger.error(f"Prediction error: {str(e)}") | |
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") | |
async def predict_batch(emails: EmailBatchInput): | |
"""Predict multiple email messages""" | |
if model_pipeline is None: | |
raise HTTPException(status_code=503, detail="Model not loaded") | |
if len(emails.messages) > 100: | |
raise HTTPException(status_code=400, detail="Maximum 100 messages per batch") | |
try: | |
predictions = [] | |
# Preprocess all messages | |
processed_messages = [preprocess_text(msg) for msg in emails.messages] | |
# Make batch predictions | |
batch_predictions = model_pipeline.predict(processed_messages) | |
batch_probabilities = model_pipeline.predict_proba(processed_messages) | |
# Prepare responses | |
for i, (message, prediction, probabilities) in enumerate( | |
zip(emails.messages, batch_predictions, batch_probabilities) | |
): | |
prediction_label = "Has attachment" if prediction == 1 else "No attachment" | |
confidence = float(max(probabilities)) | |
prob_dict = { | |
"no_attachment": float(probabilities[0]), | |
"has_attachment": float(probabilities[1]) | |
} | |
predictions.append(PredictionResponse( | |
message=message, | |
prediction=int(prediction), | |
prediction_label=prediction_label, | |
confidence=confidence, | |
probabilities=prob_dict | |
)) | |
return BatchPredictionResponse(predictions=predictions) | |
except Exception as e: | |
logger.error(f"Batch prediction error: {str(e)}") | |
raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}") | |
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=7860) |