Spaces:
Sleeping
Sleeping
Enhance FastAPI application with model training and prediction features, including detailed response models and health check endpoint.
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from fastapi import FastAPI, UploadFile, Form, HTTPException
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
from fastapi.responses import JSONResponse
|
4 |
import pandas as pd
|
@@ -9,10 +9,14 @@ from sklearn.model_selection import train_test_split
|
|
9 |
from sklearn.metrics import confusion_matrix
|
10 |
import json
|
11 |
import io
|
12 |
-
from typing import Dict, List, Optional
|
13 |
-
from pydantic import BaseModel
|
14 |
|
15 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
16 |
|
17 |
app.add_middleware(
|
18 |
CORSMiddleware,
|
@@ -22,117 +26,200 @@ app.add_middleware(
|
|
22 |
allow_headers=["*"],
|
23 |
)
|
24 |
|
25 |
-
model = None
|
26 |
-
feature_encoders: Dict[str, LabelEncoder] = {}
|
27 |
-
target_encoder: Optional[LabelEncoder] = None
|
28 |
-
|
29 |
class TrainOptions(BaseModel):
|
30 |
-
target_column: str
|
31 |
-
feature_columns: List[str]
|
32 |
|
33 |
class PredictionFeatures(BaseModel):
|
34 |
-
features: Dict[str, str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
async def health_check():
|
|
|
38 |
return {"status": "healthy"}
|
39 |
|
40 |
-
@app.post("/api/upload")
|
41 |
-
async def upload_csv(file: UploadFile):
|
|
|
42 |
if not file.filename.endswith('.csv'):
|
43 |
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
44 |
-
|
45 |
try:
|
46 |
contents = await file.read()
|
47 |
df = pd.read_csv(io.StringIO(contents.decode()))
|
48 |
-
|
49 |
columns = df.columns.tolist()
|
50 |
column_types = {col: str(df[col].dtype) for col in columns}
|
51 |
-
|
52 |
unique_values = {col: df[col].unique().tolist() for col in columns}
|
53 |
-
|
54 |
for col, values in unique_values.items():
|
55 |
unique_values[col] = [v.item() if isinstance(v, np.generic) else v for v in values]
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
}
|
64 |
except Exception as e:
|
65 |
raise HTTPException(status_code=500, detail=str(e))
|
66 |
|
67 |
-
@app.post("/api/train")
|
68 |
-
async def train_model(
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
71 |
try:
|
72 |
train_options = json.loads(options)
|
73 |
target_column = train_options["target_column"]
|
74 |
feature_columns = train_options["feature_columns"]
|
75 |
-
|
76 |
contents = await file.read()
|
77 |
df = pd.read_csv(io.StringIO(contents.decode()))
|
78 |
-
|
79 |
X = pd.DataFrame()
|
80 |
-
feature_encoders = {}
|
81 |
for column in feature_columns:
|
82 |
encoder = LabelEncoder()
|
83 |
X[column] = encoder.fit_transform(df[column])
|
84 |
-
feature_encoders[column] = encoder
|
85 |
-
|
86 |
-
|
87 |
-
y = target_encoder.fit_transform(df[target_column])
|
88 |
-
|
89 |
X_train, X_test, y_train, y_test = train_test_split(
|
90 |
X, y, test_size=0.2, random_state=42
|
91 |
)
|
92 |
-
|
93 |
-
model
|
94 |
-
model.
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
}
|
103 |
-
|
104 |
except Exception as e:
|
105 |
raise HTTPException(status_code=500, detail=str(e))
|
106 |
|
107 |
-
@app.post("/api/predict")
|
108 |
-
async def predict(
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
112 |
raise HTTPException(status_code=400, detail="Model not trained yet")
|
113 |
-
|
114 |
try:
|
115 |
encoded_features = {}
|
116 |
for column, value in features.features.items():
|
117 |
-
if column in feature_encoders:
|
118 |
-
encoded_features[column] = feature_encoders[column].transform([value])[0]
|
119 |
-
|
120 |
X = pd.DataFrame([encoded_features])
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
predicted_class = target_encoder.inverse_transform(prediction)[0]
|
126 |
-
|
127 |
class_probabilities = {
|
128 |
-
target_encoder.inverse_transform([i])[0]: float(prob)
|
129 |
for i, prob in enumerate(prediction_proba[0])
|
130 |
}
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
except Exception as e:
|
137 |
raise HTTPException(status_code=500, detail=str(e))
|
138 |
|
|
|
1 |
+
from fastapi import FastAPI, UploadFile, Form, HTTPException, Depends, status
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
from fastapi.responses import JSONResponse
|
4 |
import pandas as pd
|
|
|
9 |
from sklearn.metrics import confusion_matrix
|
10 |
import json
|
11 |
import io
|
12 |
+
from typing import Dict, List, Optional, Any
|
13 |
+
from pydantic import BaseModel, Field
|
14 |
|
15 |
+
# Application object. The keyword arguments only feed the metadata shown in
# the generated API documentation.
app = FastAPI(
    version="1.0.0",
    title="Categorical Naive Bayes API",
    description=(
        "API for uploading CSVs, training a Categorical Naive Bayes model, and making predictions."
    ),
)
|
20 |
|
21 |
app.add_middleware(
|
22 |
CORSMiddleware,
|
|
|
26 |
allow_headers=["*"],
|
27 |
)
|
28 |
|
|
|
|
|
|
|
|
|
29 |
class TrainOptions(BaseModel):
    # Describes the two keys expected in the `options` JSON of the train
    # endpoint. Plain comments are used instead of a class docstring so the
    # generated schema stays unchanged.

    # Column of the uploaded CSV the model learns to predict.
    target_column: str = Field(
        ...,
        description="The name of the target column.",
    )

    # Columns of the uploaded CSV used as model inputs.
    feature_columns: List[str] = Field(
        ...,
        description="List of feature column names.",
    )
|
32 |
|
33 |
class PredictionFeatures(BaseModel):
    # Request body of the predict endpoint. Plain comments are used instead of
    # a class docstring so the generated schema stays unchanged.

    # Maps a feature column name to the (string) value to classify.
    features: Dict[str, str] = Field(
        ...,
        description="Dictionary of feature values for prediction.",
    )
|
35 |
+
|
36 |
+
class UploadResponse(BaseModel):
    # Payload returned by the upload endpoint after a CSV has been parsed.
    # Plain comments are used instead of a class docstring so the generated
    # schema stays unchanged.

    message: str  # status text
    columns: List[str]  # column names of the parsed frame
    column_types: Dict[str, str]  # column name -> dtype rendered with str()
    unique_values: Dict[str, List[Any]]  # column name -> distinct values
    row_count: int  # len() of the parsed frame
|
42 |
+
|
43 |
+
class TrainResponse(BaseModel):
    # Payload returned by the train endpoint. Plain comments are used instead
    # of a class docstring so the generated schema stays unchanged.

    message: str  # status text
    accuracy: float  # model score on the held-out split
    target_classes: List[str]  # classes known to the target encoder
|
47 |
+
|
48 |
+
class PredictResponse(BaseModel):
    # Payload returned by the predict endpoint. Plain comments are used
    # instead of a class docstring so the generated schema stays unchanged.

    prediction: str  # decoded label of the predicted class
    probabilities: Dict[str, float]  # decoded label -> predicted probability
|
51 |
|
52 |
+
class ModelState:
    """Mutable container for the latest trained model and its companions.

    Every attribute starts out empty; the training endpoint fills them in and
    the prediction / plotting endpoints read them back.
    """

    def __init__(self):
        # Fitted classifier, or None while nothing has been trained.
        self.model: Optional[CategoricalNB] = None

        # One fitted encoder per feature column, keyed by column name.
        self.feature_encoders: Dict[str, LabelEncoder] = {}

        # Encoder for the target column, or None while nothing has been trained.
        self.target_encoder: Optional[LabelEncoder] = None

        # Held-out split of the last training run, kept for the plot endpoints.
        self.X_test: Optional[pd.DataFrame] = None
        self.y_test: Optional[np.ndarray] = None
|
59 |
+
|
60 |
+
# Single module-level instance shared by every request handler.
model_state = ModelState()


def get_model_state():
    """Dependency provider: return the module-level ``model_state`` object."""
    return model_state
|
64 |
+
|
65 |
+
@app.get(
    "/api/health",
    tags=["Health"],
    summary="Health Check",
    response_model=Dict[str, str],
)
async def health_check():
    """Check API health."""
    # Constant payload: touches neither the model state nor any uploaded data.
    payload = {"status": "healthy"}
    return payload
|
69 |
|
70 |
+
def _json_safe_scalar(value):
    """Convert one CSV cell into something the JSON encoder accepts."""
    import math

    # numpy scalars are unwrapped to their plain Python equivalent.
    if isinstance(value, np.generic):
        value = value.item()
    # NaN (pandas' marker for an empty cell) and +/-inf have no JSON
    # representation; report them as null instead of failing the response.
    if isinstance(value, float) and not math.isfinite(value):
        return None
    return value


@app.post(
    "/api/upload",
    tags=["Data"],
    summary="Upload CSV File",
    response_model=UploadResponse,
    status_code=status.HTTP_200_OK,
)
async def upload_csv(file: UploadFile) -> UploadResponse:
    """Upload a CSV file and get metadata about its columns."""
    # The docstring above is published as the endpoint description, so the
    # maintainer notes are kept in comments.
    #
    # Parameters:
    #   file: uploaded file; its name must end in ".csv" (any letter case).
    # Returns:
    #   UploadResponse with column names, dtypes, distinct values per column
    #   and the number of rows.
    # Raises:
    #   HTTPException(400): wrong/missing file name, undecodable bytes, or a
    #       file pandas cannot parse as CSV.
    #   HTTPException(500): any other failure.

    # `filename` can be absent; treat that the same as a wrong extension.
    filename = file.filename or ""
    if not filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        contents = await file.read()
        df = pd.read_csv(io.StringIO(contents.decode()))
    except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        # Problems with the uploaded bytes are the client's to fix.
        raise HTTPException(status_code=400, detail=f"Could not parse CSV: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        columns = df.columns.tolist()
        column_types = {col: str(df[col].dtype) for col in columns}
        unique_values = {
            col: [_json_safe_scalar(v) for v in df[col].unique().tolist()]
            for col in columns
        }
        return UploadResponse(
            message="File uploaded successfully",
            columns=columns,
            column_types=column_types,
            unique_values=unique_values,
            row_count=len(df),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
92 |
|
93 |
+
@app.post(
    "/api/train",
    tags=["Model"],
    summary="Train Model",
    response_model=TrainResponse,
    status_code=status.HTTP_200_OK,
)
async def train_model(
    file: UploadFile,
    options: str = Form(...),
    state: ModelState = Depends(get_model_state),
) -> TrainResponse:
    """Train a Categorical Naive Bayes model on the uploaded CSV."""
    # The docstring above is published as the endpoint description, so the
    # maintainer notes are kept in comments.
    #
    # Parameters:
    #   file: CSV upload containing the target and feature columns.
    #   options: JSON text with "target_column" and "feature_columns".
    #   state: shared ModelState that receives the fitted objects.
    # Returns:
    #   TrainResponse with the held-out accuracy and the target classes.
    # Raises:
    #   HTTPException(400): malformed options, empty feature list, or columns
    #       that are not present in the CSV.
    #   HTTPException(500): any other failure while reading or fitting.

    try:
        parsed = TrainOptions(**json.loads(options))
    except (ValueError, TypeError) as e:
        # ValueError covers invalid JSON and failed model validation;
        # TypeError covers JSON that is not an object.
        raise HTTPException(status_code=400, detail=f"Invalid training options: {e}") from e

    target_column = parsed.target_column
    feature_columns = parsed.feature_columns
    if not feature_columns:
        raise HTTPException(status_code=400, detail="At least one feature column is required")

    try:
        contents = await file.read()
        df = pd.read_csv(io.StringIO(contents.decode()))

        missing = [c for c in [target_column, *feature_columns] if c not in df.columns]
        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"Column(s) not found in CSV: {', '.join(missing)}",
            )

        # Everything is built in locals first and only copied into `state`
        # once training has succeeded, so a failed run cannot leave new
        # encoders paired with an old model.
        feature_encoders = {}
        encoded_columns = {}
        for column in feature_columns:
            encoder = LabelEncoder()
            # Values are encoded as strings because the predict endpoint
            # receives every feature value as a string.
            encoded_columns[column] = encoder.fit_transform(df[column].astype(str))
            feature_encoders[column] = encoder
        X = pd.DataFrame(encoded_columns)

        target_encoder = LabelEncoder()
        y = target_encoder.fit_transform(df[target_column].astype(str))

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # A category can end up only in the test split. Declaring the full
        # number of categories per feature lets the model score such rows
        # instead of failing on an out-of-range category index.
        model = CategoricalNB(
            min_categories=[len(feature_encoders[c].classes_) for c in X.columns]
        )
        model.fit(X_train, y_train)
        accuracy = float(model.score(X_test, y_test))

        state.feature_encoders = feature_encoders
        state.target_encoder = target_encoder
        state.model = model
        state.X_test = X_test
        state.y_test = y_test

        return TrainResponse(
            message="Model trained successfully",
            accuracy=accuracy,
            target_classes=[str(c) for c in target_encoder.classes_],
        )
    except HTTPException:
        # Keep the 400 raised above from being rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
129 |
|
130 |
+
@app.post(
    "/api/predict",
    tags=["Model"],
    summary="Predict",
    response_model=PredictResponse,
    status_code=status.HTTP_200_OK,
)
async def predict(
    features: PredictionFeatures,
    state: ModelState = Depends(get_model_state),
) -> PredictResponse:
    """Predict the target class for given features using the trained model."""
    # The docstring above is published as the endpoint description, so the
    # maintainer notes are kept in comments.
    #
    # Parameters:
    #   features: request body mapping feature column names to values.
    #   state: shared ModelState holding the fitted model and encoders.
    # Returns:
    #   PredictResponse with the decoded predicted class and one probability
    #   per class known to the model.
    # Raises:
    #   HTTPException(400): no trained model, a trained feature is missing
    #       from the request, or a value was never seen by its encoder.
    #   HTTPException(500): any failure inside the model itself.
    if state.model is None or state.target_encoder is None:
        raise HTTPException(status_code=400, detail="Model not trained yet")

    supplied = features.features
    missing = [column for column in state.feature_encoders if column not in supplied]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing feature(s): {', '.join(missing)}",
        )

    # Walk the encoders (stored in training order) rather than the request, so
    # the columns line up with what the model was fitted on. Keys in the
    # request that the model does not know are ignored, as before.
    encoded_features = {}
    for column, encoder in state.feature_encoders.items():
        value = supplied[column]
        try:
            encoded_features[column] = encoder.transform([value])[0]
        except (ValueError, TypeError) as e:
            # The encoder rejects labels it was not fitted on.
            raise HTTPException(
                status_code=400,
                detail=f"Unknown value {value!r} for feature '{column}'",
            ) from e

    try:
        X = pd.DataFrame([encoded_features])
        prediction = state.model.predict(X)
        prediction_proba = state.model.predict_proba(X)
        predicted_class = state.target_encoder.inverse_transform(prediction)[0]

        # Probability columns follow `model.classes_`, which holds only the
        # encoded labels present in the training split; decode those instead
        # of assuming column i is label i.
        class_labels = state.target_encoder.inverse_transform(state.model.classes_)
        class_probabilities = {
            str(label): float(prob)
            for label, prob in zip(class_labels, prediction_proba[0])
        }
        return PredictResponse(
            prediction=str(predicted_class),
            probabilities=class_probabilities,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
157 |
+
|
158 |
+
from fastapi.responses import StreamingResponse
|
159 |
+
import matplotlib.pyplot as plt
|
160 |
+
|
161 |
+
@app.get("/api/plot/confusion-matrix", tags=["Model"], summary="Confusion Matrix Plot")
async def plot_confusion_matrix(state: ModelState = Depends(get_model_state)):
    """Return a PNG image of the confusion matrix for the last test set."""
    # The docstring above is published as the endpoint description, so the
    # maintainer notes are kept in comments.
    #
    # Parameters:
    #   state: shared ModelState holding the model, target encoder and the
    #          held-out split of the last training run.
    # Returns:
    #   StreamingResponse carrying the rendered PNG.
    # Raises:
    #   HTTPException(400): nothing has been trained yet.
    if (
        state.model is None
        or state.target_encoder is None
        or state.X_test is None
        or state.y_test is None
    ):
        raise HTTPException(status_code=400, detail="Model not trained or no test data available.")

    classes = state.target_encoder.classes_
    y_pred = state.model.predict(state.X_test)
    # Pin the label set to every encoded class. Without it the matrix only
    # covers labels that occur in this split and no longer lines up with the
    # tick labels below.
    cm = confusion_matrix(state.y_test, y_pred, labels=np.arange(len(classes)))

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        cax = ax.matshow(cm, cmap=plt.cm.Blues)
        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        fig.colorbar(cax, ax=ax)
        ax.set_xticks(np.arange(len(classes)))
        ax.set_yticks(np.arange(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha="left")
        ax.set_yticklabels(classes)
        # Write each count into its cell (x is the column, y is the row).
        for (i, j), z in np.ndenumerate(cm):
            ax.text(j, i, str(z), ha='center', va='center', color='red')
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Release the figure even when rendering fails.
        plt.close(fig)
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
|
187 |
+
|
188 |
+
@app.get("/api/plot/feature-log-prob", tags=["Model"], summary="Feature Log Probability Heatmap")
async def plot_feature_log_prob(state: ModelState = Depends(get_model_state)):
    """Return a PNG heatmap of feature log probabilities for each class."""
    # The docstring above is published as the endpoint description, so the
    # maintainer notes are kept in comments.
    #
    # Parameters:
    #   state: shared ModelState holding the model and the fitted encoders.
    # Returns:
    #   StreamingResponse carrying one heatmap per feature, stacked vertically.
    # Raises:
    #   HTTPException(400): nothing has been trained, or no feature encoders.
    #   HTTPException(500): any failure while rendering.
    if state.model is None or state.target_encoder is None:
        raise HTTPException(status_code=400, detail="Model not trained.")

    feature_names = list(state.feature_encoders)
    if not feature_names:
        # A figure with zero rows cannot be created.
        raise HTTPException(status_code=400, detail="Model has no features to plot.")

    try:
        import seaborn as sns

        # Rows of every feature_log_prob_ entry follow `model.classes_`
        # (encoded labels), so decode exactly those for the y axis.
        class_names = list(state.target_encoder.inverse_transform(state.model.classes_))

        # squeeze=False always yields a 2-D array of axes, so a single feature
        # needs no special case.
        fig, axes = plt.subplots(
            len(feature_names), 1, figsize=(8, 4 * len(feature_names)), squeeze=False
        )
        try:
            for idx, feature in enumerate(feature_names):
                # CategoricalNB stores feature_log_prob_ as a list with one
                # (n_classes, n_categories) array per feature; it is not a
                # single 3-D array.
                log_probs = state.model.feature_log_prob_[idx]
                # The model may know fewer categories than the encoder when
                # the highest codes never appear in the training split.
                categories = list(state.feature_encoders[feature].classes_[: log_probs.shape[1]])
                ax = axes[idx, 0]
                sns.heatmap(
                    log_probs,
                    annot=True,
                    fmt=".2f",
                    cmap="Blues",
                    xticklabels=categories,
                    yticklabels=class_names,
                    ax=ax,
                )
                ax.set_title(f'Log Probabilities for Feature: {feature}')
                ax.set_xlabel('Feature Value')
                ax.set_ylabel('Class')
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
        finally:
            # Release the figure even when rendering fails.
            plt.close(fig)
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
225 |
|