import io
from typing import Any, Dict, List, Optional

import matplotlib

# Use a non-interactive backend so figures can be rendered on a headless server.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import CategoricalNB
from sklearn.preprocessing import LabelEncoder

app = FastAPI(
    title="Categorical Naive Bayes API",
    description="API for uploading CSVs, training a Categorical Naive Bayes model, and making predictions.",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class TrainOptions(BaseModel):
    target_column: str = Field(..., description="The name of the target column.")
    feature_columns: List[str] = Field(..., description="List of feature column names.")


class PredictionFeatures(BaseModel):
    features: Dict[str, str] = Field(..., description="Dictionary of feature values for prediction.")


class UploadResponse(BaseModel):
    message: str
    columns: List[str]
    column_types: Dict[str, str]
    unique_values: Dict[str, List[Any]]
    row_count: int


class TrainResponse(BaseModel):
    message: str
    accuracy: float
    target_classes: List[str]


class PredictResponse(BaseModel):
    prediction: str
    probabilities: Dict[str, float]


class ModelState:
    """Holds the trained model and encoders between requests."""

    def __init__(self):
        self.model: Optional[CategoricalNB] = None
        self.feature_encoders: Dict[str, LabelEncoder] = {}
        self.target_encoder: Optional[LabelEncoder] = None
        self.X_test: Optional[pd.DataFrame] = None
        self.y_test: Optional[np.ndarray] = None


model_state = ModelState()


def get_model_state():
    return model_state


@app.get("/api/health", tags=["Health"], summary="Health Check", response_model=Dict[str, str])
async def health_check():
    """Check API health."""
    return {"status": "healthy"}


@app.post("/api/upload", tags=["Data"], summary="Upload CSV File", response_model=UploadResponse, status_code=status.HTTP_200_OK)
async def upload_csv(
    file: UploadFile = File(..., description="CSV file to upload")
) -> UploadResponse:
    """Upload a CSV file and get metadata about its columns."""
    if not file.filename or not file.filename.lower().endswith('.csv'):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only CSV files are allowed"
        )

    try:
        contents = await file.read()

        if len(contents) == 0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Uploaded file is empty"
            )

        # Try to parse the CSV, falling back to Latin-1 if UTF-8 decoding fails
        try:
            df = pd.read_csv(io.StringIO(contents.decode('utf-8')))
        except UnicodeDecodeError:
            try:
                df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
            except Exception:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Unable to decode CSV file. Please ensure it's properly formatted."
                )
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Error parsing CSV: {str(e)}"
            )

        if df.empty:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="CSV file contains no data"
            )

        columns = df.columns.tolist()
        column_types = {col: str(df[col].dtype) for col in columns}

        # Cap unique values at 100 per column to keep response sizes manageable
        unique_values = {}
        for col in columns:
            unique_vals = df[col].unique().tolist()
            if len(unique_vals) > 100:
                unique_values[col] = unique_vals[:100] + ["... (truncated)"]
            else:
                unique_values[col] = unique_vals

        # Convert any remaining NumPy scalars to native Python types
        for col, values in unique_values.items():
            unique_values[col] = [v.item() if isinstance(v, np.generic) else v for v in values]

        return UploadResponse(
            message="File uploaded and processed successfully",
            columns=columns,
            column_types=column_types,
            unique_values=unique_values,
            row_count=len(df),
        )
    except HTTPException:
        # Re-raise HTTP exceptions unchanged
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred: {str(e)}"
        )
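
# A quick smoke test against /api/upload (illustrative; assumes the server is
# running locally on port 7860 and that "data.csv" exists):
#
#   curl -F "file=@data.csv" http://localhost:7860/api/upload
#
# The response lists columns, dtypes, and (capped) unique values, which a
# client can use to pick feature and target columns for /api/train.
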

@app.post("/api/train", tags=["Model"], summary="Train Model", response_model=TrainResponse, status_code=status.HTTP_200_OK)
async def train_model(
    file: UploadFile = File(..., description="CSV file to train on"),
    options: TrainOptions = Depends(),
    state: ModelState = Depends(get_model_state)
) -> TrainResponse:
    """Train a Categorical Naive Bayes model on the uploaded CSV.

    Parameters:
    - file: CSV file with the training data
    - options: Training options specifying target column and feature columns
    """
    try:
        contents = await file.read()
        try:
            df = pd.read_csv(io.StringIO(contents.decode('utf-8')))
        except UnicodeDecodeError:
            df = pd.read_csv(io.StringIO(contents.decode('latin-1')))

        # Validate that all requested columns exist in the DataFrame
        missing_columns = [
            col for col in [options.target_column] + options.feature_columns
            if col not in df.columns
        ]
        if missing_columns:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Columns not found in CSV: {', '.join(missing_columns)}"
            )

        # Encode features
        X = pd.DataFrame()
        state.feature_encoders = {}
        for column in options.feature_columns:
            if df[column].isna().any():
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Column '{column}' contains missing values. Please preprocess your data."
                )
            encoder = LabelEncoder()
            X[column] = encoder.fit_transform(df[column])
            state.feature_encoders[column] = encoder

        # Encode target
        if df[options.target_column].isna().any():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Target column '{options.target_column}' contains missing values."
            )
        state.target_encoder = LabelEncoder()
        y = state.target_encoder.fit_transform(df[options.target_column])

        # Train/test split
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42,
            stratify=y if len(np.unique(y)) > 1 else None
        )

        # min_categories reserves a slot for every category the encoders saw,
        # so prediction does not fail if the train split misses some categories
        state.model = CategoricalNB(
            min_categories=[len(state.feature_encoders[c].classes_) for c in options.feature_columns]
        )
        state.model.fit(X_train, y_train)
        accuracy = float(state.model.score(X_test, y_test))
        state.X_test = X_test
        state.y_test = y_test

        return TrainResponse(
            message="Model trained successfully",
            accuracy=accuracy,
            # Cast to str so numeric target classes serialize cleanly
            target_classes=[str(c) for c in state.target_encoder.classes_]
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model training failed: {str(e)}"
        )
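
# Example request against /api/train (illustrative). Because TrainOptions is
# wired in via Depends(), its fields surface as query parameters; list fields
# are usually passed as repeated keys, though the exact behavior can vary by
# FastAPI version:
#
#   curl -F "file=@data.csv" \
#        "http://localhost:7860/api/train?target_column=play&feature_columns=outlook&feature_columns=humidity"
#
# The column names here ("play", "outlook", "humidity") are placeholders for
# whatever /api/upload reported.
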

@app.post("/api/predict", tags=["Model"], summary="Predict", response_model=PredictResponse, status_code=status.HTTP_200_OK)
async def predict(
    features: PredictionFeatures,
    state: ModelState = Depends(get_model_state)
) -> PredictResponse:
    """Predict the target class for given features using the trained model.

    Parameters:
    - features: Dictionary of feature values for prediction
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained yet. Please train a model before making predictions."
        )

    try:
        # Require exactly the feature set the model was trained on
        required_features = set(state.feature_encoders.keys())
        provided_features = set(features.features.keys())

        missing_features = required_features - provided_features
        if missing_features:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Missing required features: {', '.join(missing_features)}"
            )

        extra_features = provided_features - required_features
        if extra_features:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unexpected features provided: {', '.join(extra_features)}"
            )

        # Encode features, rejecting category values the encoders never saw
        encoded_features = {}
        for column, value in features.features.items():
            try:
                encoded_features[column] = state.feature_encoders[column].transform([value])[0]
            except ValueError:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=(
                        f"Unknown value '{value}' for feature '{column}'. "
                        f"Allowed values: {', '.join(map(str, state.feature_encoders[column].classes_))}"
                    )
                )

        # Build a single-row frame, ordering columns as they were at training time
        feature_order = list(state.feature_encoders.keys())
        X = pd.DataFrame([encoded_features], columns=feature_order)
        prediction = state.model.predict(X)
        prediction_proba = state.model.predict_proba(X)

        predicted_class = state.target_encoder.inverse_transform(prediction)[0]
        class_probabilities = {
            str(state.target_encoder.inverse_transform([i])[0]): float(prob)
            for i, prob in enumerate(prediction_proba[0])
        }

        return PredictResponse(
            prediction=str(predicted_class),
            probabilities=class_probabilities
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}"
        )
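
# Example request against /api/predict (illustrative; the feature names and
# values must match categories seen at training time):
#
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"features": {"outlook": "sunny", "humidity": "high"}}'
#
# Unknown values produce a 400 response listing the allowed values.
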

@app.get(
    "/api/plot/confusion-matrix",
    tags=["Model"],
    summary="Confusion Matrix Plot",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"image/png": {}},
            "description": "PNG image of confusion matrix"
        },
        400: {
            "description": "Model not trained or no test data available"
        }
    }
)
async def plot_confusion_matrix(state: ModelState = Depends(get_model_state)):
    """Return a PNG image of the confusion matrix for the last test set."""
    if state.model is None or state.X_test is None or state.y_test is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained or no test data available."
        )

    try:
        y_pred = state.model.predict(state.X_test)
        cm = confusion_matrix(state.y_test, y_pred)

        # Create plot
        fig, ax = plt.subplots(figsize=(7, 6))
        cax = ax.matshow(cm, cmap=plt.cm.Blues)
        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        fig.colorbar(cax)

        # Label the axes with the original class names
        classes = state.target_encoder.classes_ if state.target_encoder else []
        ax.set_xticks(np.arange(len(classes)))
        ax.set_yticks(np.arange(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha="left")
        ax.set_yticklabels(classes)

        # Annotate each cell with its count
        for (i, j), z in np.ndenumerate(cm):
            ax.text(j, i, str(z), ha='center', va='center',
                    color='white' if cm[i, j] > cm.max() / 2 else 'black')

        fig.tight_layout()

        # Save to buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150)
        plt.close(fig)
        buf.seek(0)

        # Cache the rendered image for an hour
        response = StreamingResponse(buf, media_type="image/png")
        response.headers["Cache-Control"] = "max-age=3600"
        return response
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate confusion matrix: {str(e)}"
        )


@app.get(
    "/api/plot/feature-log-prob",
    tags=["Model"],
    summary="Feature Log Probability Heatmap",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"image/png": {}},
            "description": "PNG heatmap of feature log probabilities"
        },
        400: {
            "description": "Model not trained"
        }
    }
)
async def plot_feature_log_prob(state: ModelState = Depends(get_model_state)):
    """Return a PNG heatmap of feature log probabilities for each class."""
    if state.model is None or state.target_encoder is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained."
        )

    try:
        feature_names = list(state.feature_encoders.keys())
        class_names = list(state.target_encoder.classes_)

        # Scale the figure height with the number of features
        fig_height = max(4, 2 * len(feature_names))
        fig, axes = plt.subplots(len(feature_names), 1, figsize=(10, fig_height))
        if len(feature_names) == 1:
            axes = [axes]

        for idx, feature in enumerate(feature_names):
            encoder = state.feature_encoders[feature]
            categories = encoder.classes_

            # feature_log_prob_ is a list holding one (n_classes, n_categories)
            # array per feature, so index the list first
            data = state.model.feature_log_prob_[idx]

            ax = axes[idx]
            sns.heatmap(
                data,
                annot=True,
                fmt=".2f",
                cmap="Blues",
                xticklabels=categories,
                yticklabels=class_names,
                ax=ax
            )
            ax.set_title(f'Log Probabilities for Feature: {feature}')
            ax.set_xlabel('Feature Value')
            ax.set_ylabel('Class')

        fig.tight_layout()

        # Save to buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150)
        plt.close(fig)
        buf.seek(0)

        # Cache the rendered image for an hour
        response = StreamingResponse(buf, media_type="image/png")
        response.headers["Cache-Control"] = "max-age=3600"
        return response
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate feature log probability plot: {str(e)}"
        )
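
# The plot endpoints stream raw PNG bytes, so they can be saved with curl or
# embedded directly in an <img> tag (illustrative):
#
#   curl -o confusion_matrix.png http://localhost:7860/api/plot/confusion-matrix
#   curl -o feature_log_prob.png http://localhost:7860/api/plot/feature-log-prob
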

if __name__ == "__main__":
    import os

    import uvicorn

    # Get port from environment variable, defaulting to 7860 (for HF Spaces)
    port = int(os.environ.get("PORT", 7860))

    # Configure logging for better visibility
    log_config = uvicorn.config.LOGGING_CONFIG
    log_config["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
    log_config["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=True,
        log_config=log_config
    )
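
# A minimal end-to-end client sketch using the requests library (hypothetical:
# assumes the server is reachable at http://localhost:7860 and that
# "weather.csv" has categorical feature columns "outlook" and "humidity" plus
# a target column "play"):
#
#   import requests
#
#   BASE = "http://localhost:7860"
#   with open("weather.csv", "rb") as f:
#       print(requests.post(f"{BASE}/api/upload", files={"file": f}).json())
#   with open("weather.csv", "rb") as f:
#       print(requests.post(
#           f"{BASE}/api/train",
#           files={"file": f},
#           params={"target_column": "play",
#                   "feature_columns": ["outlook", "humidity"]},
#       ).json())
#   print(requests.post(
#       f"{BASE}/api/predict",
#       json={"features": {"outlook": "sunny", "humidity": "high"}},
#   ).json())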