"""FastAPI service for uploading CSVs, training a Categorical Naive Bayes model, and serving predictions and diagnostic plots."""
from fastapi import FastAPI, UploadFile, HTTPException, Depends, status, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import pandas as pd
import numpy as np
from sklearn.naive_bayes import CategoricalNB
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import io
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, needed to render plots in a headless server process
import matplotlib.pyplot as plt
import seaborn as sns
app = FastAPI(
    title="Categorical Naive Bayes API",
    description="API for uploading CSVs, training a Categorical Naive Bayes model, and making predictions.",
    version="1.0.0"
)
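# FastAPI auto-generates interactive API docs for the routes below at /docs
# (Swagger UI) and /redoc; no extra configuration is required.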
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
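# Note: browsers reject the combination of a wildcard origin and
# allow_credentials=True under the CORS spec; restrict allow_origins to known
# hosts if credentialed requests are actually needed.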
class TrainOptions(BaseModel):
    target_column: str = Field(..., description="The name of the target column.")
    feature_columns: List[str] = Field(..., description="List of feature column names.")

class PredictionFeatures(BaseModel):
    features: Dict[str, str] = Field(..., description="Dictionary of feature values for prediction.")

class UploadResponse(BaseModel):
    message: str
    columns: List[str]
    column_types: Dict[str, str]
    unique_values: Dict[str, List[Any]]
    row_count: int

class TrainResponse(BaseModel):
    message: str
    accuracy: float
    target_classes: List[str]

class PredictResponse(BaseModel):
    prediction: str
    probabilities: Dict[str, float]
class ModelState:
    def __init__(self):
        self.model: Optional[CategoricalNB] = None
        self.feature_encoders: Dict[str, LabelEncoder] = {}
        self.target_encoder: Optional[LabelEncoder] = None
        self.X_test: Optional[pd.DataFrame] = None
        self.y_test: Optional[np.ndarray] = None

model_state = ModelState()

def get_model_state():
    return model_state
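# A single module-level ModelState means one trained model per worker process:
# concurrent /train requests overwrite each other, and nothing is shared
# across multiple uvicorn workers or survives a restart.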
@app.get("/health")  # route path assumed
async def health_check():
    """Check API health."""
    return {"status": "healthy"}
@app.post("/upload", response_model=UploadResponse)  # route path assumed
async def upload_csv(
    file: UploadFile = File(..., description="CSV file to upload")
) -> UploadResponse:
    """Upload a CSV file and get metadata about its columns."""
    if not file.filename or not file.filename.lower().endswith('.csv'):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only CSV files are allowed"
        )
    try:
        contents = await file.read()
        if len(contents) == 0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Uploaded file is empty"
            )
        # Parse the CSV, falling back to latin-1 if UTF-8 decoding fails
        try:
            df = pd.read_csv(io.StringIO(contents.decode('utf-8')))
        except UnicodeDecodeError:
            try:
                df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
            except Exception:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Unable to decode CSV file. Please ensure it's properly formatted."
                )
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Error parsing CSV: {str(e)}"
            )
        if df.empty:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="CSV file contains no data"
            )
        columns = df.columns.tolist()
        column_types = {col: str(df[col].dtype) for col in columns}
        # Cap unique values at 100 per column to keep the response size bounded
        unique_values = {}
        for col in columns:
            unique_vals = df[col].unique().tolist()
            if len(unique_vals) > 100:
                unique_values[col] = unique_vals[:100] + ["... (truncated)"]
            else:
                unique_values[col] = unique_vals
        # Convert NumPy scalars to native Python types so they serialize to JSON
        for col, values in unique_values.items():
            unique_values[col] = [v.item() if isinstance(v, np.generic) else v for v in values]
        return UploadResponse(
            message="File uploaded and processed successfully",
            columns=columns,
            column_types=column_types,
            unique_values=unique_values,
            row_count=len(df)
        )
    except HTTPException:
        # Re-raise HTTP exceptions unchanged
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred: {str(e)}"
        )
@app.post("/train", response_model=TrainResponse)  # route path assumed
async def train_model(
    file: UploadFile = File(..., description="CSV file to train on"),
    # Depends() lets TrainOptions be supplied alongside the multipart file
    # upload instead of as a JSON body, which cannot be mixed with file uploads
    options: TrainOptions = Depends(),
    state: ModelState = Depends(get_model_state)
) -> TrainResponse:
    """Train a Categorical Naive Bayes model on the uploaded CSV.

    Parameters:
    - file: CSV file with the training data
    - options: Training options specifying target column and feature columns
    """
    try:
        contents = await file.read()
        try:
            df = pd.read_csv(io.StringIO(contents.decode('utf-8')))
        except UnicodeDecodeError:
            df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
        # Validate that all requested columns exist in the DataFrame
        missing_columns = [col for col in [options.target_column] + options.feature_columns
                           if col not in df.columns]
        if missing_columns:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Columns not found in CSV: {', '.join(missing_columns)}"
            )
        # Encode features: LabelEncoder maps each category to an integer index,
        # which is the input representation CategoricalNB expects
        X = pd.DataFrame()
        state.feature_encoders = {}
        for column in options.feature_columns:
            if df[column].isna().any():
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Column '{column}' contains missing values. Please preprocess your data."
                )
            encoder = LabelEncoder()
            X[column] = encoder.fit_transform(df[column])
            state.feature_encoders[column] = encoder
        # Encode target
        if df[options.target_column].isna().any():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Target column '{options.target_column}' contains missing values."
            )
        state.target_encoder = LabelEncoder()
        y = state.target_encoder.fit_transform(df[options.target_column])
        # Train/test split, stratified when there is more than one class
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y if len(np.unique(y)) > 1 else None
        )
        # min_categories tells the model how many categories each feature has in
        # the full dataset, so scoring does not fail on category indices that
        # happen to be absent from the training split
        state.model = CategoricalNB(
            min_categories=[len(state.feature_encoders[c].classes_) for c in options.feature_columns]
        )
        state.model.fit(X_train, y_train)
        accuracy = float(state.model.score(X_test, y_test))
        state.X_test = X_test
        state.y_test = y_test
        return TrainResponse(
            message="Model trained successfully",
            accuracy=accuracy,
            target_classes=[str(c) for c in state.target_encoder.classes_]
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model training failed: {str(e)}"
        )
@app.post("/predict", response_model=PredictResponse)  # route path assumed
async def predict(
    features: PredictionFeatures,
    state: ModelState = Depends(get_model_state)
) -> PredictResponse:
    """Predict the target class for given features using the trained model.

    Parameters:
    - features: Dictionary of feature values for prediction
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained yet. Please train a model before making predictions."
        )
    try:
        # Validate that exactly the trained feature set is provided
        required_features = set(state.feature_encoders.keys())
        provided_features = set(features.features.keys())
        missing_features = required_features - provided_features
        if missing_features:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Missing required features: {', '.join(missing_features)}"
            )
        extra_features = provided_features - required_features
        if extra_features:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unexpected features provided: {', '.join(extra_features)}"
            )
        # Encode features with the encoders fitted at training time
        encoded_features = {}
        for column, value in features.features.items():
            try:
                encoded_features[column] = state.feature_encoders[column].transform([value])[0]
            except ValueError:
                # The value was never seen during training
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unknown value '{value}' for feature '{column}'. Allowed values: {', '.join(map(str, state.feature_encoders[column].classes_))}"
                )
        # Build the input row with columns in training order; the client's dict
        # ordering must not determine the column order the model sees
        feature_order = list(state.feature_encoders.keys())
        X = pd.DataFrame([encoded_features])[feature_order]
        prediction = state.model.predict(X)
        prediction_proba = state.model.predict_proba(X)
        predicted_class = str(state.target_encoder.inverse_transform(prediction)[0])
        # Map each class index back to its original label with its probability
        class_probabilities = {
            str(state.target_encoder.inverse_transform([i])[0]): float(prob)
            for i, prob in enumerate(prediction_proba[0])
        }
        return PredictResponse(
            prediction=predicted_class,
            probabilities=class_probabilities
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}"
        )
@app.get("/plot/confusion-matrix")  # route path assumed
async def plot_confusion_matrix(state: ModelState = Depends(get_model_state)):
    """Return a PNG image of the confusion matrix for the last test set."""
    if state.model is None or state.X_test is None or state.y_test is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained or no test data available."
        )
    try:
        y_pred = state.model.predict(state.X_test)
        cm = confusion_matrix(state.y_test, y_pred)
        fig, ax = plt.subplots(figsize=(7, 6))
        cax = ax.matshow(cm, cmap=plt.cm.Blues)
        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        fig.colorbar(cax)
        # Label the axes with the original class names
        classes = state.target_encoder.classes_ if state.target_encoder else []
        ax.set_xticks(np.arange(len(classes)))
        ax.set_yticks(np.arange(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha="left")
        ax.set_yticklabels(classes)
        # Annotate each cell with its count, switching text color for contrast
        for (i, j), z in np.ndenumerate(cm):
            ax.text(j, i, str(z), ha='center', va='center',
                    color='white' if cm[i, j] > cm.max() / 2 else 'black')
        fig.tight_layout()
        # Render to an in-memory PNG and stream it back
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150)
        plt.close(fig)
        buf.seek(0)
        response = StreamingResponse(buf, media_type="image/png")
        response.headers["Cache-Control"] = "max-age=3600"  # cache for 1 hour
        return response
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate confusion matrix: {str(e)}"
        )
@app.get("/plot/feature-log-prob")  # route path assumed
async def plot_feature_log_prob(state: ModelState = Depends(get_model_state)):
    """Return a PNG heatmap of feature log probabilities for each class."""
    if state.model is None or state.target_encoder is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained."
        )
    try:
        feature_names = list(state.feature_encoders.keys())
        class_names = list(state.target_encoder.classes_)
        # Scale the figure height with the number of features
        fig_height = max(4, 2 * len(feature_names))
        fig, axes = plt.subplots(len(feature_names), 1, figsize=(10, fig_height))
        if len(feature_names) == 1:
            axes = [axes]
        for idx, feature in enumerate(feature_names):
            encoder = state.feature_encoders[feature]
            categories = encoder.classes_
            # CategoricalNB.feature_log_prob_ is a list with one array per
            # feature, each of shape (n_classes, n_categories)
            data = state.model.feature_log_prob_[idx]
            ax = axes[idx]
            sns.heatmap(
                data,
                annot=True,
                fmt=".2f",
                cmap="Blues",
                xticklabels=categories,
                yticklabels=class_names,
                ax=ax
            )
            ax.set_title(f'Log Probabilities for Feature: {feature}')
            ax.set_xlabel('Feature Value')
            ax.set_ylabel('Class')
        fig.tight_layout()
        # Render to an in-memory PNG and stream it back
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150)
        plt.close(fig)
        buf.seek(0)
        response = StreamingResponse(buf, media_type="image/png")
        response.headers["Cache-Control"] = "max-age=3600"  # cache for 1 hour
        return response
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate feature log probability plot: {str(e)}"
        )
if __name__ == "__main__":
    import uvicorn
    import os
    # Get the port from the environment, defaulting to 7860 (the HF Spaces convention)
    port = int(os.environ.get("PORT", 7860))
    # Add timestamps to uvicorn's log output
    log_config = uvicorn.config.LOGGING_CONFIG
    log_config["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
    log_config["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
    uvicorn.run(
        "app:app",  # import string (assumes this file is app.py); required for reload
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=True,  # auto-reload is a development convenience; disable in production
        log_config=log_config
    )
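
# Example session against a local run, assuming the route paths declared
# above; "data.csv" and the column names are placeholders, and the exact
# query-string encoding of feature_columns can vary with the FastAPI version:
#
#   curl -F "file=@data.csv" http://localhost:7860/upload
#   curl -X POST -F "file=@data.csv" \
#        "http://localhost:7860/train?target_column=play&feature_columns=outlook&feature_columns=temp"
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"features": {"outlook": "sunny", "temp": "hot"}}'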