xuanzang committed on
Commit
d1a05cc
·
1 Parent(s): adead6c

Enhance FastAPI application with model training and prediction features, including detailed response models and health check endpoint.

Browse files
Files changed (1) hide show
  1. app.py +158 -71
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, UploadFile, Form, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
  import pandas as pd
@@ -9,10 +9,14 @@ from sklearn.model_selection import train_test_split
9
  from sklearn.metrics import confusion_matrix
10
  import json
11
  import io
12
- from typing import Dict, List, Optional
13
- from pydantic import BaseModel
14
 
15
- app = FastAPI()
 
 
 
 
16
 
17
  app.add_middleware(
18
  CORSMiddleware,
@@ -22,117 +26,200 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- model = None
26
- feature_encoders: Dict[str, LabelEncoder] = {}
27
- target_encoder: Optional[LabelEncoder] = None
28
-
29
  class TrainOptions(BaseModel):
30
- target_column: str
31
- feature_columns: List[str]
32
 
33
  class PredictionFeatures(BaseModel):
34
- features: Dict[str, str]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- @app.get("/api/health")
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  async def health_check():
 
38
  return {"status": "healthy"}
39
 
40
- @app.post("/api/upload")
41
- async def upload_csv(file: UploadFile):
 
42
  if not file.filename.endswith('.csv'):
43
  raise HTTPException(status_code=400, detail="Only CSV files are allowed")
44
-
45
  try:
46
  contents = await file.read()
47
  df = pd.read_csv(io.StringIO(contents.decode()))
48
-
49
  columns = df.columns.tolist()
50
  column_types = {col: str(df[col].dtype) for col in columns}
51
-
52
  unique_values = {col: df[col].unique().tolist() for col in columns}
53
-
54
  for col, values in unique_values.items():
55
  unique_values[col] = [v.item() if isinstance(v, np.generic) else v for v in values]
56
-
57
- return {
58
- "message": "File uploaded successfully",
59
- "columns": columns,
60
- "column_types": column_types,
61
- "unique_values": unique_values,
62
- "row_count": len(df)
63
- }
64
  except Exception as e:
65
  raise HTTPException(status_code=500, detail=str(e))
66
 
67
- @app.post("/api/train")
68
- async def train_model(file: UploadFile, options: str = Form(...)):
69
- global model, feature_encoders, target_encoder
70
-
 
 
 
71
  try:
72
  train_options = json.loads(options)
73
  target_column = train_options["target_column"]
74
  feature_columns = train_options["feature_columns"]
75
-
76
  contents = await file.read()
77
  df = pd.read_csv(io.StringIO(contents.decode()))
78
-
79
  X = pd.DataFrame()
80
- feature_encoders = {}
81
  for column in feature_columns:
82
  encoder = LabelEncoder()
83
  X[column] = encoder.fit_transform(df[column])
84
- feature_encoders[column] = encoder
85
-
86
- target_encoder = LabelEncoder()
87
- y = target_encoder.fit_transform(df[target_column])
88
-
89
  X_train, X_test, y_train, y_test = train_test_split(
90
  X, y, test_size=0.2, random_state=42
91
  )
92
-
93
- model = CategoricalNB()
94
- model.fit(X_train, y_train)
95
-
96
- accuracy = float(model.score(X_test, y_test))
97
-
98
- return {
99
- "message": "Model trained successfully",
100
- "accuracy": accuracy,
101
- "target_classes": target_encoder.classes_.tolist()
102
- }
103
-
104
  except Exception as e:
105
  raise HTTPException(status_code=500, detail=str(e))
106
 
107
- @app.post("/api/predict")
108
- async def predict(features: PredictionFeatures):
109
- global model, feature_encoders, target_encoder
110
-
111
- if model is None:
 
 
112
  raise HTTPException(status_code=400, detail="Model not trained yet")
113
-
114
  try:
115
  encoded_features = {}
116
  for column, value in features.features.items():
117
- if column in feature_encoders:
118
- encoded_features[column] = feature_encoders[column].transform([value])[0]
119
-
120
  X = pd.DataFrame([encoded_features])
121
-
122
- prediction = model.predict(X)
123
- prediction_proba = model.predict_proba(X)
124
-
125
- predicted_class = target_encoder.inverse_transform(prediction)[0]
126
-
127
  class_probabilities = {
128
- target_encoder.inverse_transform([i])[0]: float(prob)
129
  for i, prob in enumerate(prediction_proba[0])
130
  }
131
-
132
- return {
133
- "prediction": predicted_class,
134
- "probabilities": class_probabilities
135
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  except Exception as e:
137
  raise HTTPException(status_code=500, detail=str(e))
138
 
 
1
+ from fastapi import FastAPI, UploadFile, Form, HTTPException, Depends, status
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
  import pandas as pd
 
9
  from sklearn.metrics import confusion_matrix
10
  import json
11
  import io
12
+ from typing import Dict, List, Optional, Any
13
+ from pydantic import BaseModel, Field
14
 
15
+ app = FastAPI(
16
+ title="Categorical Naive Bayes API",
17
+ description="API for uploading CSVs, training a Categorical Naive Bayes model, and making predictions.",
18
+ version="1.0.0"
19
+ )
20
 
21
  app.add_middleware(
22
  CORSMiddleware,
 
26
  allow_headers=["*"],
27
  )
28
 
 
 
 
 
29
# Request schema for training: names the column to predict and the columns
# to learn from. Documented with comments rather than a docstring so the
# generated OpenAPI schema text is unchanged.
class TrainOptions(BaseModel):
    target_column: str = Field(
        ...,
        description="The name of the target column.",
    )
    feature_columns: List[str] = Field(
        ...,
        description="List of feature column names.",
    )
32
 
33
# Request schema for prediction: one string value per feature name.
# Documented with comments rather than a docstring so the generated OpenAPI
# schema text is unchanged.
class PredictionFeatures(BaseModel):
    features: Dict[str, str] = Field(
        ...,
        description="Dictionary of feature values for prediction.",
    )
35
+
36
# Response schema for the CSV upload endpoint.
class UploadResponse(BaseModel):
    message: str                        # human-readable status text
    columns: List[str]                  # column names
    column_types: Dict[str, str]        # column name -> dtype rendered as a string
    unique_values: Dict[str, List[Any]] # column name -> distinct values
    row_count: int                      # number of rows
42
+
43
# Response schema for the training endpoint.
class TrainResponse(BaseModel):
    message: str               # human-readable status text
    accuracy: float            # accuracy score reported by the trainer
    target_classes: List[str]  # class labels of the target column
47
+
48
# Response schema for the prediction endpoint.
class PredictResponse(BaseModel):
    prediction: str                 # predicted class label
    probabilities: Dict[str, float] # class label -> probability
51
 
52
class ModelState:
    """Mutable holder for a trained model and the artifacts that go with it.

    Every attribute starts out empty (``None`` or an empty dict).
    """

    def __init__(self):
        # Fitted classifier, or None when nothing has been stored yet.
        self.model: Optional[CategoricalNB] = None
        # Feature column name -> encoder fitted on that column.
        self.feature_encoders: Dict[str, LabelEncoder] = {}
        # Encoder fitted on the target column.
        self.target_encoder: Optional[LabelEncoder] = None
        # Held-out split kept for later evaluation.
        self.X_test: Optional[pd.DataFrame] = None
        self.y_test: Optional[np.ndarray] = None


# Module-level instance handed to the endpoints.
model_state = ModelState()


def get_model_state():
    """Return the module-level ``ModelState`` instance."""
    return model_state
64
+
65
@app.get(
    "/api/health",
    tags=["Health"],
    summary="Health Check",
    response_model=Dict[str, str],
)
async def health_check():
    """Check API health."""
    # Constant payload; the endpoint inspects no state.
    payload = {"status": "healthy"}
    return payload
69
 
70
def _to_json_scalar(value):
    """Return *value* as a plain Python scalar that serialises to valid JSON."""
    if isinstance(value, np.generic):
        value = value.item()
    # CSV cells are scalars, so pd.isna() yields a single bool here. Missing
    # values (NaN/None) are reported as null because NaN is not valid JSON.
    if pd.isna(value):
        return None
    return value


@app.post("/api/upload", tags=["Data"], summary="Upload CSV File", response_model=UploadResponse, status_code=status.HTTP_200_OK)
async def upload_csv(file: UploadFile) -> UploadResponse:
    """Upload a CSV file and get metadata about its columns."""
    # The endpoint docstring is kept short because FastAPI publishes it in the
    # OpenAPI description; details live in comments instead.
    #
    # Returns: UploadResponse with column names, dtypes, distinct values per
    #          column and the row count.
    # Raises:  HTTPException 400 when the upload is not a parseable CSV,
    #          HTTPException 500 for any other failure.

    # filename is optional on an upload; treat a missing name like a wrong
    # extension instead of crashing on None.endswith. The check is
    # case-insensitive so "DATA.CSV" is accepted too.
    if not (file.filename or "").lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        contents = await file.read()
        df = pd.read_csv(io.StringIO(contents.decode()))

        columns = df.columns.tolist()
        column_types = {col: str(df[col].dtype) for col in columns}
        unique_values = {
            col: [_to_json_scalar(v) for v in df[col].unique().tolist()]
            for col in columns
        }
        return UploadResponse(
            message="File uploaded successfully",
            columns=columns,
            column_types=column_types,
            unique_values=unique_values,
            row_count=len(df)
        )
    except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        # Bad bytes or malformed CSV are the client's fault, not a server error.
        raise HTTPException(status_code=400, detail=f"Invalid CSV file: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
92
 
93
@app.post("/api/train", tags=["Model"], summary="Train Model", response_model=TrainResponse, status_code=status.HTTP_200_OK)
async def train_model(
    file: UploadFile,
    options: str = Form(...),
    state: ModelState = Depends(get_model_state)
) -> TrainResponse:
    """Train a Categorical Naive Bayes model on the uploaded CSV."""
    # The endpoint docstring is kept short because FastAPI publishes it in the
    # OpenAPI description; details live in comments instead.
    #
    # options: JSON text with "target_column" and "feature_columns".
    # state:   shared ModelState; replaced only after training succeeds.
    # Returns: TrainResponse with the held-out accuracy and the class labels.
    # Raises:  HTTPException 400 for invalid options or unknown columns,
    #          HTTPException 500 for any other failure.
    try:
        # Validate through the declared schema so a malformed payload is a
        # client error. JSON decode errors and pydantic validation errors are
        # both ValueError subclasses; TypeError covers a non-object payload.
        try:
            train_options = TrainOptions(**json.loads(options))
        except (ValueError, TypeError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid training options: {e}") from e
        target_column = train_options.target_column
        feature_columns = train_options.feature_columns
        if not feature_columns:
            raise HTTPException(status_code=400, detail="At least one feature column is required")

        contents = await file.read()
        df = pd.read_csv(io.StringIO(contents.decode()))

        missing = [c for c in [*feature_columns, target_column] if c not in df.columns]
        if missing:
            raise HTTPException(status_code=400, detail=f"Columns not found in CSV: {missing}")

        # Build everything in locals first: if any step below fails, the
        # previously trained model and its encoders stay intact and consistent.
        feature_encoders = {}
        X = pd.DataFrame()
        for column in feature_columns:
            encoder = LabelEncoder()
            X[column] = encoder.fit_transform(df[column])
            feature_encoders[column] = encoder
        target_encoder = LabelEncoder()
        y = target_encoder.fit_transform(df[target_column])

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # Tell the model how many categories each feature really has. Without
        # this, a category that lands only in the test split is out of range
        # for the fitted probability tables.
        model = CategoricalNB(
            min_categories=[len(feature_encoders[c].classes_) for c in X.columns]
        )
        model.fit(X_train, y_train)
        accuracy = float(model.score(X_test, y_test))

        # Training succeeded: publish the new artifacts together.
        state.feature_encoders = feature_encoders
        state.target_encoder = target_encoder
        state.model = model
        state.X_test = X_test
        state.y_test = y_test

        return TrainResponse(
            message="Model trained successfully",
            accuracy=accuracy,
            # Labels may be numeric; the response schema promises strings.
            target_classes=[str(label) for label in target_encoder.classes_]
        )
    except HTTPException:
        # Keep the deliberate 4xx responses from being rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
129
 
130
@app.post("/api/predict", tags=["Model"], summary="Predict", response_model=PredictResponse, status_code=status.HTTP_200_OK)
async def predict(
    features: PredictionFeatures,
    state: ModelState = Depends(get_model_state)
) -> PredictResponse:
    """Predict the target class for given features using the trained model."""
    # The endpoint docstring is kept short because FastAPI publishes it in the
    # OpenAPI description; details live in comments instead.
    #
    # features: one string value per trained feature; extra keys are ignored.
    # Returns:  PredictResponse with the predicted label and per-class
    #           probabilities.
    # Raises:   HTTPException 400 when no model is trained, a trained feature
    #           is missing, or a value was never seen during training;
    #           HTTPException 500 for any other failure.
    if state.model is None or state.target_encoder is None:
        raise HTTPException(status_code=400, detail="Model not trained yet")

    try:
        supplied = features.features
        missing = [name for name in state.feature_encoders if name not in supplied]
        if missing:
            raise HTTPException(status_code=400, detail=f"Missing feature(s): {missing}")

        # Encode in the encoders' own order (the order the model was fitted
        # with), not in whatever order the request happened to list the keys.
        encoded_features = {}
        for column, encoder in state.feature_encoders.items():
            raw_value = supplied[column]
            try:
                # Request values are always strings; cast to the encoder's
                # label dtype so numerically-encoded features can match.
                typed_value = np.asarray([raw_value]).astype(encoder.classes_.dtype)
                encoded_features[column] = encoder.transform(typed_value)[0]
            except (ValueError, TypeError) as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unknown value {raw_value!r} for feature '{column}'"
                ) from e

        X = pd.DataFrame([encoded_features], columns=list(state.feature_encoders))
        prediction = state.model.predict(X)
        prediction_proba = state.model.predict_proba(X)
        predicted_class = state.target_encoder.inverse_transform(prediction)[0]

        # predict_proba columns follow model.classes_, which can be a subset of
        # the encoder's classes, so map through it rather than by position.
        class_labels = state.target_encoder.inverse_transform(state.model.classes_)
        class_probabilities = {
            str(label): float(prob)
            for label, prob in zip(class_labels, prediction_proba[0])
        }
        return PredictResponse(
            # Labels may be numeric; the response schema promises strings.
            prediction=str(predicted_class),
            probabilities=class_probabilities
        )
    except HTTPException:
        # Keep the deliberate 4xx responses from being rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
157
+
158
+ from fastapi.responses import StreamingResponse
159
+ import matplotlib.pyplot as plt
160
+
161
@app.get("/api/plot/confusion-matrix", tags=["Model"], summary="Confusion Matrix Plot")
async def plot_confusion_matrix(state: ModelState = Depends(get_model_state)):
    """Return a PNG image of the confusion matrix for the last test set."""
    # The endpoint docstring is kept short because FastAPI publishes it in the
    # OpenAPI description; details live in comments instead.
    #
    # Returns: StreamingResponse carrying the PNG bytes.
    # Raises:  HTTPException 400 when no model or test split is stored.
    if (
        state.model is None
        or state.X_test is None
        or state.y_test is None
        or state.target_encoder is None
    ):
        raise HTTPException(status_code=400, detail="Model not trained or no test data available.")

    classes = state.target_encoder.classes_
    y_pred = state.model.predict(state.X_test)
    # Pin the label set to every encoded class. Otherwise a class absent from
    # the test split shrinks the matrix and the tick labels no longer line up.
    cm = confusion_matrix(state.y_test, y_pred, labels=np.arange(len(classes)))

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        # Draw through fig/ax directly instead of pyplot's implicit
        # "current figure", which is shared module state.
        image = ax.matshow(cm, cmap=plt.cm.Blues)
        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        fig.colorbar(image, ax=ax)
        ax.set_xticks(np.arange(len(classes)))
        ax.set_yticks(np.arange(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha="left")
        ax.set_yticklabels(classes)
        for (i, j), z in np.ndenumerate(cm):
            ax.text(j, i, str(z), ha='center', va='center', color='red')
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even when drawing fails.
        plt.close(fig)
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
187
+
188
@app.get("/api/plot/feature-log-prob", tags=["Model"], summary="Feature Log Probability Heatmap")
async def plot_feature_log_prob(state: ModelState = Depends(get_model_state)):
    """Return a PNG heatmap of feature log probabilities for each class."""
    # The endpoint docstring is kept short because FastAPI publishes it in the
    # OpenAPI description; details live in comments instead.
    #
    # Returns: StreamingResponse carrying the PNG bytes, one heatmap per feature.
    # Raises:  HTTPException 400 when no model is stored or it has no features,
    #          HTTPException 500 for any other failure.
    if state.model is None or state.target_encoder is None:
        raise HTTPException(status_code=400, detail="Model not trained.")
    feature_names = list(state.feature_encoders.keys())
    if not feature_names:
        # plt.subplots cannot create a zero-row grid.
        raise HTTPException(status_code=400, detail="Model has no features to plot.")
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns

        # Heatmap rows follow model.classes_, which can be a subset of the
        # encoder's classes, so derive the row labels from it.
        class_names = list(state.target_encoder.inverse_transform(state.model.classes_))

        # squeeze=False always yields a 2-D axes array, so the single-feature
        # case needs no special handling.
        fig, axes = plt.subplots(
            len(feature_names), 1,
            figsize=(8, 4 * len(feature_names)),
            squeeze=False,
        )
        try:
            for idx, feature in enumerate(feature_names):
                # CategoricalNB.feature_log_prob_ is a list with one
                # (n_classes, n_categories) array per feature, not a 3-D array.
                log_probs = np.asarray(state.model.feature_log_prob_[idx])
                # Keep only as many category labels as the table has columns.
                categories = list(state.feature_encoders[feature].classes_[:log_probs.shape[1]])
                ax = axes[idx, 0]
                sns.heatmap(log_probs, annot=True, fmt=".2f", cmap="Blues", xticklabels=categories, yticklabels=class_names, ax=ax)
                ax.set_title(f'Log Probabilities for Feature: {feature}')
                ax.set_xlabel('Feature Value')
                ax.set_ylabel('Class')
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
        finally:
            # Always release the figure, even when drawing fails.
            plt.close(fig)
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
225