from fastapi import FastAPI, UploadFile, HTTPException, Depends, status, File, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import pandas as pd
import numpy as np
from sklearn.naive_bayes import CategoricalNB
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import io
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field

import matplotlib

matplotlib.use("Agg")  # headless backend: plots are rendered to PNG buffers, never to a window
import matplotlib.pyplot as plt
import seaborn as sns

app = FastAPI(
    title="Categorical Naive Bayes API",
    description="API for uploading CSVs, training a Categorical Naive Bayes model, and making predictions.",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class TrainOptions(BaseModel):
    target_column: str = Field(..., description="The name of the target column.")
    # Explicit Query marker so the list is read as repeated query parameters
    # (?feature_columns=a&feature_columns=b). A bare List[str] field would be
    # treated as a JSON body, which cannot be combined with the multipart
    # file upload on /api/train.
    feature_columns: List[str] = Query(..., description="List of feature column names.")


class PredictionFeatures(BaseModel):
    features: Dict[str, str] = Field(..., description="Dictionary of feature values for prediction.")


class UploadResponse(BaseModel):
    message: str
    columns: List[str]
    column_types: Dict[str, str]
    unique_values: Dict[str, List[Any]]
    row_count: int


class TrainResponse(BaseModel):
    message: str
    accuracy: float
    target_classes: List[str]


class PredictResponse(BaseModel):
    prediction: str
    probabilities: Dict[str, float]

class ModelState:
    """In-process holder for the trained model, its encoders, and the test split.

    Note: this state lives in one process's memory; it is not shared across
    workers and is lost on restart.
    """

    def __init__(self):
        self.model: Optional[CategoricalNB] = None
        self.feature_encoders: Dict[str, LabelEncoder] = {}
        self.target_encoder: Optional[LabelEncoder] = None
        self.X_test: Optional[pd.DataFrame] = None
        self.y_test: Optional[np.ndarray] = None


model_state = ModelState()


def get_model_state():
    return model_state

@app.get("/api/health", tags=["Health"], summary="Health Check", response_model=Dict[str, str])
async def health_check():
"""Check API health."""
return {"status": "healthy"}
@app.post("/api/upload", tags=["Data"], summary="Upload CSV File", response_model=UploadResponse, status_code=status.HTTP_200_OK)
async def upload_csv(
file: UploadFile = File(..., description="CSV file to upload")
) -> UploadResponse:
"""Upload a CSV file and get metadata about its columns."""
if not file.filename or not file.filename.lower().endswith('.csv'):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only CSV files are allowed"
)
try:
contents = await file.read()
# Check if the file content is valid
if len(contents) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Uploaded file is empty"
)
# Try to parse the CSV
try:
df = pd.read_csv(io.StringIO(contents.decode('utf-8')))
except UnicodeDecodeError:
# Try another encoding if UTF-8 fails
try:
df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
except Exception:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Unable to decode CSV file. Please ensure it's properly formatted."
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error parsing CSV: {str(e)}"
)
if df.empty:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="CSV file contains no data"
)
# Process the data
columns = df.columns.tolist()
column_types = {col: str(df[col].dtype) for col in columns}
# Limit the number of unique values to prevent excessive response sizes
unique_values = {}
for col in columns:
unique_vals = df[col].unique().tolist()
# Limit to 100 values max to prevent excessive response size
if len(unique_vals) > 100:
unique_values[col] = unique_vals[:100] + ["... (truncated)"]
else:
unique_values[col] = unique_vals
# Convert NumPy objects to Python native types
for col, values in unique_values.items():
unique_values[col] = [v.item() if isinstance(v, np.generic) else v for v in values]
return UploadResponse(
message="File uploaded and processed successfully",
columns=columns,
column_types=column_types,
unique_values=unique_values,
row_count=len(df)
)
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An unexpected error occurred: {str(e)}"
)
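# Example upload ("weather.csv" is a placeholder file name):
#   curl -X POST http://localhost:7860/api/upload -F "file=@weather.csv"
# Returns the column names, dtypes, (capped) unique values, and row count.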
@app.post("/api/train", tags=["Model"], summary="Train Model", response_model=TrainResponse, status_code=status.HTTP_200_OK)
async def train_model(
file: UploadFile = File(..., description="CSV file to train on"),
options: TrainOptions = Depends(),
state: ModelState = Depends(get_model_state)
) -> TrainResponse:
"""Train a Categorical Naive Bayes model on the uploaded CSV.
Parameters:
- file: CSV file with the training data
- options: Training options specifying target column and feature columns
"""
try:
contents = await file.read()
try:
df = pd.read_csv(io.StringIO(contents.decode('utf-8')))
except UnicodeDecodeError:
df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
# Validate columns exist in the DataFrame
missing_columns = [col for col in [options.target_column] + options.feature_columns
if col not in df.columns]
if missing_columns:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Columns not found in CSV: {', '.join(missing_columns)}"
)
# Initialize data structures
X = pd.DataFrame()
state.feature_encoders = {}
# Encode features
for column in options.feature_columns:
if df[column].isna().any():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Column '{column}' contains missing values. Please preprocess your data."
)
encoder = LabelEncoder()
X[column] = encoder.fit_transform(df[column])
state.feature_encoders[column] = encoder
# Encode target
if df[options.target_column].isna().any():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Target column '{options.target_column}' contains missing values."
)
state.target_encoder = LabelEncoder()
y = state.target_encoder.fit_transform(df[options.target_column])
# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y if len(np.unique(y)) > 1 else None
)
# Train the model
state.model = CategoricalNB()
state.model.fit(X_train, y_train)
accuracy = float(state.model.score(X_test, y_test))
state.X_test = X_test
state.y_test = y_test
return TrainResponse(
message="Model trained successfully",
accuracy=accuracy,
target_classes=list(state.target_encoder.classes_)
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Model training failed: {str(e)}"
)
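# Example training call (hypothetical column/file names; feature_columns is
# repeated once per feature because the options travel as query parameters):
#   curl -X POST "http://localhost:7860/api/train?target_column=play&feature_columns=outlook&feature_columns=windy" \
#        -F "file=@weather.csv"
#   -> {"message": "Model trained successfully", "accuracy": <test-set accuracy>, "target_classes": [...]}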
@app.post("/api/predict", tags=["Model"], summary="Predict", response_model=PredictResponse, status_code=status.HTTP_200_OK)
async def predict(
features: PredictionFeatures,
state: ModelState = Depends(get_model_state)
) -> PredictResponse:
"""Predict the target class for given features using the trained model.
Parameters:
- features: Dictionary of feature values for prediction
"""
if state.model is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Model not trained yet. Please train a model before making predictions."
)
try:
# Validate that all required features are provided
required_features = set(state.feature_encoders.keys())
provided_features = set(features.features.keys())
missing_features = required_features - provided_features
if missing_features:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Missing required features: {', '.join(missing_features)}"
)
# Validate extra features
extra_features = provided_features - required_features
if extra_features:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unexpected features provided: {', '.join(extra_features)}"
)
# Encode features
encoded_features = {}
for column, value in features.features.items():
try:
encoded_features[column] = state.feature_encoders[column].transform([value])[0]
except ValueError:
# Handle unknown category values
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown value '{value}' for feature '{column}'. Allowed values: {', '.join(map(str, state.feature_encoders[column].classes_))}"
)
# Make prediction
X = pd.DataFrame([encoded_features])
prediction = state.model.predict(X)
prediction_proba = state.model.predict_proba(X)
predicted_class = state.target_encoder.inverse_transform(prediction)[0]
# Generate probabilities
class_probabilities = {
state.target_encoder.inverse_transform([i])[0]: float(prob)
for i, prob in enumerate(prediction_proba[0])
}
return PredictResponse(
prediction=predicted_class,
probabilities=class_probabilities
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Prediction failed: {str(e)}"
)
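# Example prediction (hypothetical feature names/values; they must match the
# columns and categories the model was trained on):
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"features": {"outlook": "sunny", "windy": "false"}}'
#   -> {"prediction": "...", "probabilities": {"...": 0.0, ...}}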
@app.get(
    "/api/plot/confusion-matrix",
    tags=["Model"],
    summary="Confusion Matrix Plot",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"image/png": {}},
            "description": "PNG image of confusion matrix"
        },
        400: {
            "description": "Model not trained or no test data available"
        }
    }
)
async def plot_confusion_matrix(state: ModelState = Depends(get_model_state)):
    """Return a PNG image of the confusion matrix for the last test set."""
    if state.model is None or state.X_test is None or state.y_test is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained or no test data available."
        )
    try:
        y_pred = state.model.predict(state.X_test)
        cm = confusion_matrix(state.y_test, y_pred)
        # Create plot
        fig, ax = plt.subplots(figsize=(7, 6))
        cax = ax.matshow(cm, cmap=plt.cm.Blues)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.colorbar(cax)
        # Add class labels on both axes
        classes = state.target_encoder.classes_ if state.target_encoder else []
        ax.set_xticks(np.arange(len(classes)))
        ax.set_yticks(np.arange(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha="left")
        ax.set_yticklabels(classes)
        # Annotate each cell with its count
        for (i, j), z in np.ndenumerate(cm):
            ax.text(j, i, str(z), ha='center', va='center',
                    color='white' if cm[i, j] > cm.max() / 2 else 'black')
        plt.tight_layout()
        # Save to an in-memory buffer
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150)
        plt.close(fig)
        buf.seek(0)
        # Create response with cache control headers
        response = StreamingResponse(buf, media_type="image/png")
        response.headers["Cache-Control"] = "max-age=3600"  # Cache for 1 hour
        return response
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate confusion matrix: {str(e)}"
        )
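# Example (saves the rendered matrix for the most recent test split):
#   curl -o confusion_matrix.png http://localhost:7860/api/plot/confusion-matrix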
@app.get(
    "/api/plot/feature-log-prob",
    tags=["Model"],
    summary="Feature Log Probability Heatmap",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"image/png": {}},
            "description": "PNG heatmap of feature log probabilities"
        },
        400: {
            "description": "Model not trained"
        }
    }
)
async def plot_feature_log_prob(state: ModelState = Depends(get_model_state)):
    """Return a PNG heatmap of feature log probabilities for each class."""
    if state.model is None or state.target_encoder is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained."
        )
    try:
        feature_names = list(state.feature_encoders.keys())
        class_names = list(state.target_encoder.classes_)
        # Scale the figure height with the number of features
        fig_height = max(4, 2 * len(feature_names))
        fig, axes = plt.subplots(len(feature_names), 1, figsize=(10, fig_height))
        if len(feature_names) == 1:
            axes = [axes]
        for idx, feature in enumerate(feature_names):
            encoder = state.feature_encoders[feature]
            categories = encoder.classes_
            # feature_log_prob_ is a list with one (n_classes, n_categories)
            # array per feature, so index the feature first; it is not a
            # single 3-D array
            data = state.model.feature_log_prob_[idx]
            ax = axes[idx]
            # Create heatmap
            sns.heatmap(
                data,
                annot=True,
                fmt=".2f",
                cmap="Blues",
                xticklabels=categories,
                yticklabels=class_names,
                ax=ax
            )
            ax.set_title(f'Log Probabilities for Feature: {feature}')
            ax.set_xlabel('Feature Value')
            ax.set_ylabel('Class')
        plt.tight_layout()
        # Save to an in-memory buffer
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150)
        plt.close(fig)
        buf.seek(0)
        # Create response with cache control headers
        response = StreamingResponse(buf, media_type="image/png")
        response.headers["Cache-Control"] = "max-age=3600"  # Cache for 1 hour
        return response
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate feature log probability plot: {str(e)}"
        )
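# Example (one heatmap per trained feature, stacked vertically):
#   curl -o feature_log_prob.png http://localhost:7860/api/plot/feature-log-prob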
if __name__ == "__main__":
    import uvicorn
    import os

    # Get port from environment variable or default to 7860 (for HF Spaces)
    port = int(os.environ.get("PORT", 7860))
    # Configure logging for better visibility
    log_config = uvicorn.config.LOGGING_CONFIG
    log_config["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
    log_config["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=True,
        log_config=log_config
    )