Spaces:
Sleeping
Sleeping
- __pycache__/app.cpython-313.pyc +0 -0
- app.py +300 -63
- requirements.txt +10 -8
__pycache__/app.cpython-313.pyc
ADDED
Binary file (21.2 kB). View file
|
|
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
from fastapi import FastAPI, UploadFile, Form, HTTPException, Depends, status
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
-
from fastapi.responses import JSONResponse
|
4 |
import pandas as pd
|
5 |
import numpy as np
|
6 |
from sklearn.naive_bayes import CategoricalNB
|
@@ -11,6 +11,9 @@ import json
|
|
11 |
import io
|
12 |
from typing import Dict, List, Optional, Any
|
13 |
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
14 |
|
15 |
app = FastAPI(
|
16 |
title="Categorical Naive Bayes API",
|
@@ -68,161 +71,395 @@ async def health_check():
|
|
68 |
return {"status": "healthy"}
|
69 |
|
70 |
@app.post("/api/upload", tags=["Data"], summary="Upload CSV File", response_model=UploadResponse, status_code=status.HTTP_200_OK)
|
71 |
-
async def upload_csv(
|
|
|
|
|
72 |
"""Upload a CSV file and get metadata about its columns."""
|
73 |
-
if not file.filename.endswith('.csv'):
|
74 |
-
raise HTTPException(
|
|
|
|
|
|
|
|
|
75 |
try:
|
76 |
contents = await file.read()
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
columns = df.columns.tolist()
|
79 |
column_types = {col: str(df[col].dtype) for col in columns}
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
for col, values in unique_values.items():
|
82 |
unique_values[col] = [v.item() if isinstance(v, np.generic) else v for v in values]
|
|
|
83 |
return UploadResponse(
|
84 |
-
message="File uploaded successfully",
|
85 |
columns=columns,
|
86 |
column_types=column_types,
|
87 |
unique_values=unique_values,
|
88 |
row_count=len(df)
|
89 |
)
|
|
|
|
|
|
|
90 |
except Exception as e:
|
91 |
-
raise HTTPException(
|
|
|
|
|
|
|
92 |
|
93 |
@app.post("/api/train", tags=["Model"], summary="Train Model", response_model=TrainResponse, status_code=status.HTTP_200_OK)
|
94 |
async def train_model(
|
95 |
-
file: UploadFile,
|
96 |
-
options:
|
97 |
state: ModelState = Depends(get_model_state)
|
98 |
) -> TrainResponse:
|
99 |
-
"""Train a Categorical Naive Bayes model on the uploaded CSV.
|
|
|
|
|
|
|
|
|
|
|
100 |
try:
|
101 |
-
train_options = json.loads(options)
|
102 |
-
target_column = train_options["target_column"]
|
103 |
-
feature_columns = train_options["feature_columns"]
|
104 |
contents = await file.read()
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
X = pd.DataFrame()
|
107 |
state.feature_encoders = {}
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
encoder = LabelEncoder()
|
110 |
X[column] = encoder.fit_transform(df[column])
|
111 |
state.feature_encoders[column] = encoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
state.target_encoder = LabelEncoder()
|
113 |
-
y = state.target_encoder.fit_transform(df[target_column])
|
|
|
|
|
114 |
X_train, X_test, y_train, y_test = train_test_split(
|
115 |
-
X, y, test_size=0.2, random_state=42
|
116 |
)
|
|
|
|
|
117 |
state.model = CategoricalNB()
|
118 |
state.model.fit(X_train, y_train)
|
119 |
accuracy = float(state.model.score(X_test, y_test))
|
120 |
state.X_test = X_test
|
121 |
state.y_test = y_test
|
|
|
122 |
return TrainResponse(
|
123 |
message="Model trained successfully",
|
124 |
accuracy=accuracy,
|
125 |
target_classes=list(state.target_encoder.classes_)
|
126 |
)
|
|
|
|
|
127 |
except Exception as e:
|
128 |
-
raise HTTPException(
|
|
|
|
|
|
|
129 |
|
130 |
@app.post("/api/predict", tags=["Model"], summary="Predict", response_model=PredictResponse, status_code=status.HTTP_200_OK)
|
131 |
async def predict(
|
132 |
features: PredictionFeatures,
|
133 |
state: ModelState = Depends(get_model_state)
|
134 |
) -> PredictResponse:
|
135 |
-
"""Predict the target class for given features using the trained model.
|
|
|
|
|
|
|
|
|
136 |
if state.model is None:
|
137 |
-
raise HTTPException(
|
|
|
|
|
|
|
|
|
138 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
encoded_features = {}
|
140 |
for column, value in features.features.items():
|
141 |
-
|
142 |
encoded_features[column] = state.feature_encoders[column].transform([value])[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
X = pd.DataFrame([encoded_features])
|
144 |
prediction = state.model.predict(X)
|
145 |
prediction_proba = state.model.predict_proba(X)
|
146 |
predicted_class = state.target_encoder.inverse_transform(prediction)[0]
|
|
|
|
|
147 |
class_probabilities = {
|
148 |
state.target_encoder.inverse_transform([i])[0]: float(prob)
|
149 |
for i, prob in enumerate(prediction_proba[0])
|
150 |
}
|
|
|
151 |
return PredictResponse(
|
152 |
prediction=predicted_class,
|
153 |
probabilities=class_probabilities
|
154 |
)
|
|
|
|
|
155 |
except Exception as e:
|
156 |
-
raise HTTPException(
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
|
161 |
-
@app.get(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
async def plot_confusion_matrix(state: ModelState = Depends(get_model_state)):
|
163 |
"""Return a PNG image of the confusion matrix for the last test set."""
|
164 |
if state.model is None or state.X_test is None or state.y_test is None:
|
165 |
-
raise HTTPException(
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
async def plot_feature_log_prob(state: ModelState = Depends(get_model_state)):
|
190 |
"""Return a PNG heatmap of feature log probabilities for each class."""
|
191 |
if state.model is None or state.target_encoder is None:
|
192 |
-
raise HTTPException(
|
|
|
|
|
|
|
|
|
193 |
try:
|
194 |
-
import matplotlib.pyplot as plt
|
195 |
-
import seaborn as sns
|
196 |
feature_names = list(state.feature_encoders.keys())
|
197 |
class_names = list(state.target_encoder.classes_)
|
198 |
-
|
199 |
-
#
|
200 |
-
|
|
|
201 |
if len(feature_names) == 1:
|
202 |
axes = [axes]
|
|
|
203 |
for idx, feature in enumerate(feature_names):
|
204 |
encoder = state.feature_encoders[feature]
|
205 |
categories = encoder.classes_
|
206 |
data = []
|
|
|
207 |
for class_idx, class_name in enumerate(class_names):
|
208 |
# For each class, get the log prob for each value of this feature
|
209 |
log_probs = state.model.feature_log_prob_[class_idx, idx, :]
|
210 |
data.append(log_probs)
|
|
|
211 |
data = np.array(data)
|
212 |
ax = axes[idx]
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
ax.set_title(f'Log Probabilities for Feature: {feature}')
|
215 |
ax.set_xlabel('Feature Value')
|
216 |
ax.set_ylabel('Class')
|
|
|
217 |
plt.tight_layout()
|
|
|
|
|
218 |
buf = io.BytesIO()
|
219 |
-
plt.savefig(buf, format='png')
|
220 |
plt.close(fig)
|
221 |
buf.seek(0)
|
222 |
-
|
|
|
|
|
|
|
|
|
223 |
except Exception as e:
|
224 |
-
raise HTTPException(
|
|
|
|
|
|
|
225 |
|
226 |
if __name__ == "__main__":
|
227 |
import uvicorn
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, UploadFile, Form, HTTPException, Depends, status, File
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
4 |
import pandas as pd
|
5 |
import numpy as np
|
6 |
from sklearn.naive_bayes import CategoricalNB
|
|
|
11 |
import io
|
12 |
from typing import Dict, List, Optional, Any
|
13 |
from pydantic import BaseModel, Field
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
import seaborn as sns
|
16 |
+
from fastapi.encoders import jsonable_encoder
|
17 |
|
18 |
app = FastAPI(
|
19 |
title="Categorical Naive Bayes API",
|
|
|
71 |
return {"status": "healthy"}
|
72 |
|
73 |
@app.post("/api/upload", tags=["Data"], summary="Upload CSV File", response_model=UploadResponse, status_code=status.HTTP_200_OK)
async def upload_csv(
    file: UploadFile = File(..., description="CSV file to upload")
) -> UploadResponse:
    """Upload a CSV file and get metadata about its columns.

    Parameters:
    - file: CSV upload; the filename must end in ".csv" (case-insensitive)

    Returns:
    - UploadResponse with the column names, the pandas dtype of every column,
      at most 100 unique values per column (followed by a truncation marker
      when there are more) and the number of rows.

    Raises:
    - HTTPException 400: wrong extension, empty upload, undecodable or
      unparsable CSV, or a CSV without data rows
    - HTTPException 500: any other unexpected failure
    """
    if not file.filename or not file.filename.lower().endswith('.csv'):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only CSV files are allowed"
        )

    try:
        contents = await file.read()
        if not contents:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Uploaded file is empty"
            )

        # Try UTF-8 first ('utf-8-sig' also accepts a leading BOM, as written
        # by some spreadsheet exports), then fall back to latin-1.
        try:
            df = pd.read_csv(io.StringIO(contents.decode('utf-8-sig')))
        except UnicodeDecodeError:
            try:
                df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
            except Exception as e:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Unable to decode CSV file. Please ensure it's properly formatted."
                ) from e
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Error parsing CSV: {str(e)}"
            ) from e

        if df.empty:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="CSV file contains no data"
            )

        columns = df.columns.tolist()
        column_types = {col: str(df[col].dtype) for col in columns}

        # Cap the number of unique values per column to bound the response size.
        max_unique = 100
        unique_values = {}
        for col in columns:
            raw_values = df[col].unique().tolist()
            # Missing cells come back as NaN/NaT, which are not valid JSON
            # values; report them as null instead. Remaining NumPy scalars are
            # converted to native Python types.
            cleaned = [
                None if pd.isna(v) else (v.item() if isinstance(v, np.generic) else v)
                for v in raw_values[:max_unique]
            ]
            if len(raw_values) > max_unique:
                cleaned.append("... (truncated)")
            unique_values[col] = cleaned

        return UploadResponse(
            message="File uploaded and processed successfully",
            columns=columns,
            column_types=column_types,
            unique_values=unique_values,
            row_count=len(df)
        )
    except HTTPException:
        # Deliberate client errors raised above must not be rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred: {str(e)}"
        ) from e
|
150 |
|
151 |
@app.post("/api/train", tags=["Model"], summary="Train Model", response_model=TrainResponse, status_code=status.HTTP_200_OK)
async def train_model(
    file: UploadFile = File(..., description="CSV file to train on"),
    options: TrainOptions = Depends(),
    state: ModelState = Depends(get_model_state)
) -> TrainResponse:
    """Train a Categorical Naive Bayes model on the uploaded CSV.

    Parameters:
    - file: CSV file with the training data
    - options: Training options specifying target column and feature columns
    - state: Shared model state; it is only updated after training succeeds

    Returns:
    - TrainResponse with the accuracy on the held-out 20% split and the
      target classes.

    Raises:
    - HTTPException 400: empty or unparsable CSV, no feature columns, unknown
      columns, fewer than two rows, or missing values in a used column
    - HTTPException 500: any other failure during training
    """
    try:
        contents = await file.read()
        if not contents:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Uploaded file is empty"
            )

        # A file that cannot be parsed is a client error, not a server error.
        try:
            try:
                df = pd.read_csv(io.StringIO(contents.decode('utf-8-sig')))
            except UnicodeDecodeError:
                df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Error parsing CSV: {str(e)}"
            ) from e

        if not options.feature_columns:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="At least one feature column is required"
            )

        # Validate columns exist in the DataFrame
        missing_columns = [col for col in [options.target_column] + options.feature_columns
                           if col not in df.columns]
        if missing_columns:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Columns not found in CSV: {', '.join(map(str, missing_columns))}"
            )

        # A train/test split needs at least one row on each side.
        if len(df) < 2:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="CSV file must contain at least two rows to train a model"
            )

        # Build everything in locals first so a failure half-way through does
        # not leave the shared state with new encoders but an old model.
        feature_encoders = {}
        encoded_columns = {}
        for column in options.feature_columns:
            if df[column].isna().any():
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Column '{column}' contains missing values. Please preprocess your data."
                )
            encoder = LabelEncoder()
            encoded_columns[column] = encoder.fit_transform(df[column])
            feature_encoders[column] = encoder
        X = pd.DataFrame(encoded_columns)

        if df[options.target_column].isna().any():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Target column '{options.target_column}' contains missing values."
            )

        target_encoder = LabelEncoder()
        y = target_encoder.fit_transform(df[options.target_column])

        # Prefer a stratified split, but fall back to a plain split when
        # stratification is impossible (a class with a single sample, or a
        # split smaller than the number of classes).
        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42,
                stratify=y if len(np.unique(y)) > 1 else None
            )
        except ValueError:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42
            )

        # Tell the model how many categories each feature has in the full data
        # set; otherwise a category that only appears in the test split (or in
        # a later prediction request) is out of range for the fitted tables.
        model = CategoricalNB(
            min_categories=[len(feature_encoders[column].classes_) for column in X.columns]
        )
        model.fit(X_train, y_train)
        accuracy = float(model.score(X_test, y_test))

        # Training succeeded: publish the new model and its encoders together.
        state.feature_encoders = feature_encoders
        state.target_encoder = target_encoder
        state.model = model
        state.X_test = X_test
        state.y_test = y_test

        return TrainResponse(
            message="Model trained successfully",
            accuracy=accuracy,
            # tolist() yields native Python values rather than NumPy scalars.
            target_classes=target_encoder.classes_.tolist()
        )
    except HTTPException:
        # Deliberate client errors raised above must not be rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model training failed: {str(e)}"
        ) from e
|
229 |
|
230 |
@app.post("/api/predict", tags=["Model"], summary="Predict", response_model=PredictResponse, status_code=status.HTTP_200_OK)
async def predict(
    features: PredictionFeatures,
    state: ModelState = Depends(get_model_state)
) -> PredictResponse:
    """Predict the target class for given features using the trained model.

    Parameters:
    - features: Dictionary of feature values for prediction
    - state: Shared model state holding the trained model and its encoders

    Returns:
    - PredictResponse with the predicted class and the probability of every
      class the model was fitted on.

    Raises:
    - HTTPException 400: no trained model, missing or unexpected features, or
      a feature value that was not present in the training data
    - HTTPException 500: any other failure during prediction
    """
    if state.model is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained yet. Please train a model before making predictions."
        )

    try:
        provided = features.features
        # Training order of the features; also the column order the model expects.
        required_features = list(state.feature_encoders)

        missing_features = [name for name in required_features if name not in provided]
        if missing_features:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Missing required features: {', '.join(map(str, missing_features))}"
            )

        extra_features = sorted(str(name) for name in provided if name not in state.feature_encoders)
        if extra_features:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unexpected features provided: {', '.join(extra_features)}"
            )

        # Encode in training order, not in the order the client happened to
        # send the keys: scikit-learn rejects a DataFrame whose feature names
        # are ordered differently than at fit time.
        encoded_features = {}
        for column in required_features:
            value = provided[column]
            encoder = state.feature_encoders[column]
            try:
                encoded_features[column] = encoder.transform([value])[0]
            except (ValueError, TypeError):
                # Unknown category value (or a value of an incompatible type)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unknown value '{value}' for feature '{column}'. Allowed values: {', '.join(map(str, encoder.classes_))}"
                ) from None

        X = pd.DataFrame([encoded_features], columns=required_features)
        prediction = state.model.predict(X)
        prediction_proba = state.model.predict_proba(X)[0]

        def to_native(label):
            """Convert a NumPy scalar label to its plain Python equivalent."""
            return label.item() if isinstance(label, np.generic) else label

        predicted_class = to_native(state.target_encoder.inverse_transform(prediction)[0])

        # predict_proba columns follow model.classes_, which can be a subset
        # of the encoder's classes when a class never reached the training
        # split, so map through model.classes_ instead of the column position.
        class_labels = state.target_encoder.inverse_transform(state.model.classes_)
        class_probabilities = {
            to_native(label): float(prob)
            for label, prob in zip(class_labels, prediction_proba)
        }

        return PredictResponse(
            prediction=predicted_class,
            probabilities=class_probabilities
        )
    except HTTPException:
        # Deliberate client errors raised above must not be rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {str(e)}"
        ) from e
|
301 |
|
302 |
+
@app.get(
    "/api/plot/confusion-matrix",
    tags=["Model"],
    summary="Confusion Matrix Plot",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"image/png": {}},
            "description": "PNG image of confusion matrix"
        },
        400: {
            "description": "Model not trained or no test data available"
        }
    }
)
async def plot_confusion_matrix(state: ModelState = Depends(get_model_state)):
    """Return a PNG image of the confusion matrix for the last test set.

    Parameters:
    - state: Shared model state holding the trained model and the test split

    Returns:
    - StreamingResponse with media type image/png

    Raises:
    - HTTPException 400: no trained model or no stored test split
    - HTTPException 500: the plot could not be generated
    """
    if state.model is None or state.X_test is None or state.y_test is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained or no test data available."
        )

    fig = None
    try:
        y_pred = state.model.predict(state.X_test)

        # Pin the matrix to the full list of encoded classes. Without explicit
        # labels the matrix only covers classes present in y_test/y_pred, and
        # the tick labels below would be attached to the wrong rows/columns.
        if state.target_encoder is not None:
            classes = list(state.target_encoder.classes_)
            cm = confusion_matrix(state.y_test, y_pred, labels=np.arange(len(classes)))
        else:
            classes = []
            cm = confusion_matrix(state.y_test, y_pred)

        # Use the figure/axes objects directly instead of pyplot's implicit
        # "current figure", which is global state shared by all requests.
        fig, ax = plt.subplots(figsize=(7, 6))
        cax = ax.matshow(cm, cmap=plt.cm.Blues)
        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        fig.colorbar(cax)

        # Add labels
        ax.set_xticks(np.arange(len(classes)))
        ax.set_yticks(np.arange(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha="left")
        ax.set_yticklabels(classes)

        # Write each count into its cell; white text on the dark cells.
        threshold = cm.max() / 2
        for (i, j), count in np.ndenumerate(cm):
            ax.text(j, i, str(count), ha='center', va='center',
                    color='white' if count > threshold else 'black')

        fig.tight_layout()

        # Save to buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150)
        buf.seek(0)

        # The image depends on the most recently trained model, so it must not
        # be served from a cache after a retrain.
        return StreamingResponse(
            buf,
            media_type="image/png",
            headers={"Cache-Control": "no-store"}
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate confusion matrix: {str(e)}"
        ) from e
    finally:
        # Always release the figure, including when plotting failed half-way.
        if fig is not None:
            plt.close(fig)
|
366 |
+
|
367 |
+
@app.get(
    "/api/plot/feature-log-prob",
    tags=["Model"],
    summary="Feature Log Probability Heatmap",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"image/png": {}},
            "description": "PNG heatmap of feature log probabilities"
        },
        400: {
            "description": "Model not trained"
        }
    }
)
async def plot_feature_log_prob(state: ModelState = Depends(get_model_state)):
    """Return a PNG heatmap of feature log probabilities for each class.

    One heatmap is drawn per feature: rows are the target classes the model
    was fitted on, columns are the feature's categories.

    Parameters:
    - state: Shared model state holding the trained model and its encoders

    Returns:
    - StreamingResponse with media type image/png

    Raises:
    - HTTPException 400: no trained model
    - HTTPException 500: the plot could not be generated
    """
    if state.model is None or state.target_encoder is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Model not trained."
        )

    fig = None
    try:
        feature_names = list(state.feature_encoders.keys())
        # Rows of the log-probability tables follow model.classes_, so derive
        # the row labels from it rather than from every class the encoder knows.
        class_names = state.target_encoder.inverse_transform(state.model.classes_).tolist()

        # Calculate plot size based on data
        fig_height = max(4, 2 * len(feature_names))
        # squeeze=False always returns a 2-D array of axes, so the
        # single-feature case needs no special handling.
        fig, axes = plt.subplots(
            len(feature_names), 1, figsize=(10, fig_height), squeeze=False
        )

        for idx, feature in enumerate(feature_names):
            # CategoricalNB.feature_log_prob_ is a list with one
            # (n_classes, n_categories) array per feature; it cannot be
            # indexed with a (class, feature, category) tuple.
            data = np.asarray(state.model.feature_log_prob_[idx])

            # The model may have seen fewer categories than the encoder knows;
            # keep only as many labels as there are columns.
            categories = list(state.feature_encoders[feature].classes_[:data.shape[1]])

            ax = axes[idx, 0]
            sns.heatmap(
                data,
                annot=True,
                fmt=".2f",
                cmap="Blues",
                xticklabels=categories,
                yticklabels=class_names,
                ax=ax
            )

            ax.set_title(f'Log Probabilities for Feature: {feature}')
            ax.set_xlabel('Feature Value')
            ax.set_ylabel('Class')

        fig.tight_layout()

        # Save to buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150)
        buf.seek(0)

        # The image depends on the most recently trained model, so it must not
        # be served from a cache after a retrain.
        return StreamingResponse(
            buf,
            media_type="image/png",
            headers={"Cache-Control": "no-store"}
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate feature log probability plot: {str(e)}"
        ) from e
    finally:
        # Always release the figure, including when plotting failed half-way.
        if fig is not None:
            plt.close(fig)
|
445 |
|
446 |
if __name__ == "__main__":
    import copy
    import os

    import uvicorn

    # Get port from environment variable or default to 7860 (for HF Spaces)
    port = int(os.environ.get("PORT", "7860"))

    # Auto-reload is a development aid (it starts a file watcher and a
    # supervisor process), so it is opt-in via UVICORN_RELOAD=1 instead of
    # always on for a server bound to all interfaces.
    reload_enabled = os.environ.get("UVICORN_RELOAD", "").strip().lower() in ("1", "true", "yes")

    # Configure logging for better visibility. Work on a copy so uvicorn's
    # module-level default configuration is not modified in place.
    log_config = copy.deepcopy(uvicorn.config.LOGGING_CONFIG)
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    log_config["formatters"]["access"]["fmt"] = log_format
    log_config["formatters"]["default"]["fmt"] = log_format

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=reload_enabled,
        log_config=log_config
    )
|
requirements.txt
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
-
fastapi
|
2 |
-
uvicorn
|
3 |
-
python-multipart
|
4 |
-
pandas
|
5 |
-
scikit-learn
|
6 |
-
numpy
|
7 |
-
matplotlib
|
8 |
-
|
|
|
|
|
|
1 |
+
fastapi>=0.104.0
|
2 |
+
uvicorn>=0.23.2
|
3 |
+
python-multipart>=0.0.6
|
4 |
+
pandas>=2.0.0
|
5 |
+
scikit-learn>=1.3.0
|
6 |
+
numpy>=1.24.0
|
7 |
+
matplotlib>=3.7.0
|
8 |
+
seaborn>=0.12.2
|
9 |
+
gunicorn>=21.2.0
|
10 |
+
pydantic>=2.0.0
|