xuanzang committed on
Commit 545fc5b · Parent(s): d1a05cc

Files changed (3):
  1. __pycache__/app.cpython-313.pyc +0 -0
  2. app.py +300 -63
  3. requirements.txt +10 -8

__pycache__/app.cpython-313.pyc ADDED
Binary file (21.2 kB).
 
app.py CHANGED
@@ -1,6 +1,6 @@
- from fastapi import FastAPI, UploadFile, Form, HTTPException, Depends, status
+ from fastapi import FastAPI, UploadFile, Form, HTTPException, Depends, status, File
  from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
+ from fastapi.responses import JSONResponse, StreamingResponse
  import pandas as pd
  import numpy as np
  from sklearn.naive_bayes import CategoricalNB
@@ -11,6 +11,9 @@ import json
  import io
  from typing import Dict, List, Optional, Any
  from pydantic import BaseModel, Field
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from fastapi.encoders import jsonable_encoder

  app = FastAPI(
      title="Categorical Naive Bayes API",
@@ -68,161 +71,395 @@ async def health_check():
      return {"status": "healthy"}

  @app.post("/api/upload", tags=["Data"], summary="Upload CSV File", response_model=UploadResponse, status_code=status.HTTP_200_OK)
- async def upload_csv(file: UploadFile) -> UploadResponse:
+ async def upload_csv(
+     file: UploadFile = File(..., description="CSV file to upload")
+ ) -> UploadResponse:
      """Upload a CSV file and get metadata about its columns."""
-     if not file.filename.endswith('.csv'):
-         raise HTTPException(status_code=400, detail="Only CSV files are allowed")
+     if not file.filename or not file.filename.lower().endswith('.csv'):
+         raise HTTPException(
+             status_code=status.HTTP_400_BAD_REQUEST,
+             detail="Only CSV files are allowed"
+         )
+
      try:
          contents = await file.read()
-         df = pd.read_csv(io.StringIO(contents.decode()))
+         # Check if the file content is valid
+         if len(contents) == 0:
+             raise HTTPException(
+                 status_code=status.HTTP_400_BAD_REQUEST,
+                 detail="Uploaded file is empty"
+             )
+
+         # Try to parse the CSV
+         try:
+             df = pd.read_csv(io.StringIO(contents.decode('utf-8')))
+         except UnicodeDecodeError:
+             # Try another encoding if UTF-8 fails
+             try:
+                 df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
+             except Exception:
+                 raise HTTPException(
+                     status_code=status.HTTP_400_BAD_REQUEST,
+                     detail="Unable to decode CSV file. Please ensure it's properly formatted."
+                 )
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_400_BAD_REQUEST,
+                 detail=f"Error parsing CSV: {str(e)}"
+             )
+
+         if df.empty:
+             raise HTTPException(
+                 status_code=status.HTTP_400_BAD_REQUEST,
+                 detail="CSV file contains no data"
+             )
+
+         # Process the data
          columns = df.columns.tolist()
          column_types = {col: str(df[col].dtype) for col in columns}
-         unique_values = {col: df[col].unique().tolist() for col in columns}
+
+         # Limit the number of unique values to prevent excessive response sizes
+         unique_values = {}
+         for col in columns:
+             unique_vals = df[col].unique().tolist()
+             # Limit to 100 values max to prevent excessive response size
+             if len(unique_vals) > 100:
+                 unique_values[col] = unique_vals[:100] + ["... (truncated)"]
+             else:
+                 unique_values[col] = unique_vals
+
+         # Convert NumPy objects to Python native types
          for col, values in unique_values.items():
              unique_values[col] = [v.item() if isinstance(v, np.generic) else v for v in values]
+
          return UploadResponse(
-             message="File uploaded successfully",
+             message="File uploaded and processed successfully",
              columns=columns,
              column_types=column_types,
              unique_values=unique_values,
              row_count=len(df)
          )
+     except HTTPException:
+         # Re-raise HTTP exceptions
+         raise
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"An unexpected error occurred: {str(e)}"
+         )

  @app.post("/api/train", tags=["Model"], summary="Train Model", response_model=TrainResponse, status_code=status.HTTP_200_OK)
  async def train_model(
-     file: UploadFile,
-     options: str = Form(...),
+     file: UploadFile = File(..., description="CSV file to train on"),
+     options: TrainOptions = Depends(),
      state: ModelState = Depends(get_model_state)
  ) -> TrainResponse:
-     """Train a Categorical Naive Bayes model on the uploaded CSV."""
+     """Train a Categorical Naive Bayes model on the uploaded CSV.
+
+     Parameters:
+     - file: CSV file with the training data
+     - options: Training options specifying target column and feature columns
+     """
      try:
-         train_options = json.loads(options)
-         target_column = train_options["target_column"]
-         feature_columns = train_options["feature_columns"]
          contents = await file.read()
-         df = pd.read_csv(io.StringIO(contents.decode()))
+
+         try:
+             df = pd.read_csv(io.StringIO(contents.decode('utf-8')))
+         except UnicodeDecodeError:
+             df = pd.read_csv(io.StringIO(contents.decode('latin-1')))
+
+         # Validate columns exist in the DataFrame
+         missing_columns = [col for col in [options.target_column] + options.feature_columns
+                            if col not in df.columns]
+         if missing_columns:
+             raise HTTPException(
+                 status_code=status.HTTP_400_BAD_REQUEST,
+                 detail=f"Columns not found in CSV: {', '.join(missing_columns)}"
+             )
+
+         # Initialize data structures
          X = pd.DataFrame()
          state.feature_encoders = {}
-         for column in feature_columns:
+
+         # Encode features
+         for column in options.feature_columns:
+             if df[column].isna().any():
+                 raise HTTPException(
+                     status_code=status.HTTP_400_BAD_REQUEST,
+                     detail=f"Column '{column}' contains missing values. Please preprocess your data."
+                 )
              encoder = LabelEncoder()
              X[column] = encoder.fit_transform(df[column])
              state.feature_encoders[column] = encoder
+
+         # Encode target
+         if df[options.target_column].isna().any():
+             raise HTTPException(
+                 status_code=status.HTTP_400_BAD_REQUEST,
+                 detail=f"Target column '{options.target_column}' contains missing values."
+             )
+
          state.target_encoder = LabelEncoder()
-         y = state.target_encoder.fit_transform(df[target_column])
+         y = state.target_encoder.fit_transform(df[options.target_column])
+
+         # Train/test split
          X_train, X_test, y_train, y_test = train_test_split(
-             X, y, test_size=0.2, random_state=42
+             X, y, test_size=0.2, random_state=42, stratify=y if len(np.unique(y)) > 1 else None
          )
+
+         # Train the model
          state.model = CategoricalNB()
          state.model.fit(X_train, y_train)
          accuracy = float(state.model.score(X_test, y_test))
         state.X_test = X_test
         state.y_test = y_test
+
          return TrainResponse(
              message="Model trained successfully",
              accuracy=accuracy,
              target_classes=list(state.target_encoder.classes_)
          )
+     except HTTPException:
+         raise
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Model training failed: {str(e)}"
+         )

  @app.post("/api/predict", tags=["Model"], summary="Predict", response_model=PredictResponse, status_code=status.HTTP_200_OK)
  async def predict(
      features: PredictionFeatures,
      state: ModelState = Depends(get_model_state)
  ) -> PredictResponse:
-     """Predict the target class for given features using the trained model."""
+     """Predict the target class for given features using the trained model.
+
+     Parameters:
+     - features: Dictionary of feature values for prediction
+     """
      if state.model is None:
-         raise HTTPException(status_code=400, detail="Model not trained yet")
+         raise HTTPException(
+             status_code=status.HTTP_400_BAD_REQUEST,
+             detail="Model not trained yet. Please train a model before making predictions."
+         )
+
      try:
+         # Validate that all required features are provided
+         required_features = set(state.feature_encoders.keys())
+         provided_features = set(features.features.keys())
+
+         missing_features = required_features - provided_features
+         if missing_features:
+             raise HTTPException(
+                 status_code=status.HTTP_400_BAD_REQUEST,
+                 detail=f"Missing required features: {', '.join(missing_features)}"
+             )
+
+         # Validate extra features
+         extra_features = provided_features - required_features
+         if extra_features:
+             raise HTTPException(
+                 status_code=status.HTTP_400_BAD_REQUEST,
+                 detail=f"Unexpected features provided: {', '.join(extra_features)}"
+             )
+
+         # Encode features
          encoded_features = {}
          for column, value in features.features.items():
-             if column in state.feature_encoders:
+             try:
                  encoded_features[column] = state.feature_encoders[column].transform([value])[0]
+             except ValueError:
+                 # Handle unknown category values
+                 raise HTTPException(
+                     status_code=status.HTTP_400_BAD_REQUEST,
+                     detail=f"Unknown value '{value}' for feature '{column}'. Allowed values: {', '.join(map(str, state.feature_encoders[column].classes_))}"
+                 )
+
+         # Make prediction
          X = pd.DataFrame([encoded_features])
          prediction = state.model.predict(X)
          prediction_proba = state.model.predict_proba(X)
          predicted_class = state.target_encoder.inverse_transform(prediction)[0]
+
+         # Generate probabilities
          class_probabilities = {
              state.target_encoder.inverse_transform([i])[0]: float(prob)
              for i, prob in enumerate(prediction_proba[0])
          }
+
          return PredictResponse(
              prediction=predicted_class,
              probabilities=class_probabilities
          )
+     except HTTPException:
+         raise
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
- from fastapi.responses import StreamingResponse
- import matplotlib.pyplot as plt
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Prediction failed: {str(e)}"
+         )

- @app.get("/api/plot/confusion-matrix", tags=["Model"], summary="Confusion Matrix Plot")
+ @app.get(
+     "/api/plot/confusion-matrix",
+     tags=["Model"],
+     summary="Confusion Matrix Plot",
+     response_class=StreamingResponse,
+     responses={
+         200: {
+             "content": {"image/png": {}},
+             "description": "PNG image of confusion matrix"
+         },
+         400: {
+             "description": "Model not trained or no test data available"
+         }
+     }
+ )
  async def plot_confusion_matrix(state: ModelState = Depends(get_model_state)):
      """Return a PNG image of the confusion matrix for the last test set."""
      if state.model is None or state.X_test is None or state.y_test is None:
-         raise HTTPException(status_code=400, detail="Model not trained or no test data available.")
-     y_pred = state.model.predict(state.X_test)
-     cm = confusion_matrix(state.y_test, y_pred)
-     fig, ax = plt.subplots(figsize=(5, 4))
-     cax = ax.matshow(cm, cmap=plt.cm.Blues)
-     plt.title('Confusion Matrix')
-     plt.xlabel('Predicted')
-     plt.ylabel('Actual')
-     plt.colorbar(cax)
-     classes = state.target_encoder.classes_ if state.target_encoder else []
-     ax.set_xticks(np.arange(len(classes)))
-     ax.set_yticks(np.arange(len(classes)))
-     ax.set_xticklabels(classes, rotation=45, ha="left")
-     ax.set_yticklabels(classes)
-     for (i, j), z in np.ndenumerate(cm):
-         ax.text(j, i, str(z), ha='center', va='center', color='red')
-     plt.tight_layout()
-     buf = io.BytesIO()
-     plt.savefig(buf, format='png')
-     plt.close(fig)
-     buf.seek(0)
-     return StreamingResponse(buf, media_type="image/png")
-
- @app.get("/api/plot/feature-log-prob", tags=["Model"], summary="Feature Log Probability Heatmap")
+         raise HTTPException(
+             status_code=status.HTTP_400_BAD_REQUEST,
+             detail="Model not trained or no test data available."
+         )
+
+     try:
+         y_pred = state.model.predict(state.X_test)
+         cm = confusion_matrix(state.y_test, y_pred)
+
+         # Create plot
+         fig, ax = plt.subplots(figsize=(7, 6))
+         cax = ax.matshow(cm, cmap=plt.cm.Blues)
+         plt.title('Confusion Matrix')
+         plt.xlabel('Predicted')
+         plt.ylabel('Actual')
+         plt.colorbar(cax)
+
+         # Add labels
+         classes = state.target_encoder.classes_ if state.target_encoder else []
+         ax.set_xticks(np.arange(len(classes)))
+         ax.set_yticks(np.arange(len(classes)))
+         ax.set_xticklabels(classes, rotation=45, ha="left")
+         ax.set_yticklabels(classes)
+
+         # Add numbers to the plot
+         for (i, j), z in np.ndenumerate(cm):
+             ax.text(j, i, str(z), ha='center', va='center',
+                     color='white' if cm[i, j] > cm.max() / 2 else 'black')
+
+         plt.tight_layout()
+
+         # Save to buffer
+         buf = io.BytesIO()
+         plt.savefig(buf, format='png', dpi=150)
+         plt.close(fig)
+         buf.seek(0)
+
+         # Create response with cache control headers
+         response = StreamingResponse(buf, media_type="image/png")
+         response.headers["Cache-Control"] = "max-age=3600"  # Cache for 1 hour
+         return response
+     except Exception as e:
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to generate confusion matrix: {str(e)}"
+         )
+
+ @app.get(
+     "/api/plot/feature-log-prob",
+     tags=["Model"],
+     summary="Feature Log Probability Heatmap",
+     response_class=StreamingResponse,
+     responses={
+         200: {
+             "content": {"image/png": {}},
+             "description": "PNG heatmap of feature log probabilities"
+         },
+         400: {
+             "description": "Model not trained"
+         }
+     }
+ )
  async def plot_feature_log_prob(state: ModelState = Depends(get_model_state)):
      """Return a PNG heatmap of feature log probabilities for each class."""
      if state.model is None or state.target_encoder is None:
-         raise HTTPException(status_code=400, detail="Model not trained.")
+         raise HTTPException(
+             status_code=status.HTTP_400_BAD_REQUEST,
+             detail="Model not trained."
+         )
+
      try:
-         import matplotlib.pyplot as plt
-         import seaborn as sns
          feature_names = list(state.feature_encoders.keys())
          class_names = list(state.target_encoder.classes_)
-         # CategoricalNB: feature_log_prob_ shape (n_classes, n_features, n_categories)
-         # We'll plot for each feature, the log prob for each class and each value
-         fig, axes = plt.subplots(len(feature_names), 1, figsize=(8, 4 * len(feature_names)))
+
+         # Calculate plot size based on data
+         fig_height = max(4, 2 * len(feature_names))
+         fig, axes = plt.subplots(len(feature_names), 1, figsize=(10, fig_height))
          if len(feature_names) == 1:
              axes = [axes]
+
          for idx, feature in enumerate(feature_names):
              encoder = state.feature_encoders[feature]
              categories = encoder.classes_
              data = []
+
              for class_idx, class_name in enumerate(class_names):
                  # For each class, get the log prob for each value of this feature
                  log_probs = state.model.feature_log_prob_[class_idx, idx, :]
                  data.append(log_probs)
+
              data = np.array(data)
              ax = axes[idx]
-             sns.heatmap(data, annot=True, fmt=".2f", cmap="Blues", xticklabels=categories, yticklabels=class_names, ax=ax)
+
+             # Create heatmap
+             sns.heatmap(
+                 data,
+                 annot=True,
+                 fmt=".2f",
+                 cmap="Blues",
+                 xticklabels=categories,
+                 yticklabels=class_names,
+                 ax=ax
+             )
+
              ax.set_title(f'Log Probabilities for Feature: {feature}')
              ax.set_xlabel('Feature Value')
              ax.set_ylabel('Class')
+
          plt.tight_layout()
+
+         # Save to buffer
          buf = io.BytesIO()
-         plt.savefig(buf, format='png')
+         plt.savefig(buf, format='png', dpi=150)
          plt.close(fig)
          buf.seek(0)
-         return StreamingResponse(buf, media_type="image/png")
+
+         # Create response with cache control headers
+         response = StreamingResponse(buf, media_type="image/png")
+         response.headers["Cache-Control"] = "max-age=3600"  # Cache for 1 hour
+         return response
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to generate feature log probability plot: {str(e)}"
+         )

  if __name__ == "__main__":
      import uvicorn
-     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
+     import os
+
+     # Get port from environment variable or default to 7860 (for HF Spaces)
+     port = int(os.environ.get("PORT", 7860))
+
+     # Configure logging for better visibility
+     log_config = uvicorn.config.LOGGING_CONFIG
+     log_config["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
+     log_config["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
+
+     uvicorn.run(
+         "app:app",
+         host="0.0.0.0",
+         port=port,
+         log_level="info",
+         reload=True,
+         log_config=log_config
+     )
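
Note on the new /api/train signature: options moves from a JSON string in a form field (options: str = Form(...)) to a typed dependency (options: TrainOptions = Depends()), so clients now send target_column and feature_columns as individual request parameters rather than one JSON blob. A minimal client sketch, assuming the server runs locally on port 7860, that TrainOptions (defined in the unchanged part of app.py) maps its fields to query parameters (FastAPI's default for Pydantic models used with Depends()), and an illustrative data.csv with columns outlook, humidity, and play — none of these names come from the diff itself:

    # Hypothetical client flow for this commit's /api/upload and /api/train.
    import requests

    BASE = "http://localhost:7860"

    # 1) Inspect the CSV: columns, dtypes, (truncated) unique values, row count.
    with open("data.csv", "rb") as f:
        meta = requests.post(f"{BASE}/api/upload",
                             files={"file": ("data.csv", f, "text/csv")})
    meta.raise_for_status()
    print(meta.json()["columns"])

    # 2) Train: the file is still multipart; the options travel as
    #    repeated query parameters instead of a JSON string in a form field.
    with open("data.csv", "rb") as f:
        r = requests.post(
            f"{BASE}/api/train",
            params={"target_column": "play",
                    "feature_columns": ["outlook", "humidity"]},
            files={"file": ("data.csv", f, "text/csv")},
        )
    r.raise_for_status()
    print(r.json()["accuracy"], r.json()["target_classes"])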
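
/api/predict now returns 400s with actionable messages for missing, extra, or unseen feature values; previously an unknown category raised a ValueError inside LabelEncoder.transform and surfaced as a 500. A sketch of both paths, reusing the same hypothetical columns as above:

    # Hypothetical prediction calls; feature names and values are illustrative.
    import requests

    BASE = "http://localhost:7860"

    # Happy path: every trained feature present, every value seen in training.
    ok = requests.post(f"{BASE}/api/predict",
                       json={"features": {"outlook": "sunny", "humidity": "high"}})
    print(ok.status_code, ok.json())  # 200 {"prediction": ..., "probabilities": {...}}

    # Unseen category: now a 400 that lists the allowed values for the feature.
    bad = requests.post(f"{BASE}/api/predict",
                        json={"features": {"outlook": "foggy", "humidity": "high"}})
    print(bad.status_code, bad.json()["detail"])  # 400 "Unknown value 'foggy' ..."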
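
Both plot endpoints stream PNG bytes (media_type="image/png") with a one-hour Cache-Control header, so fetching them is a plain GET; a sketch that saves both diagnostics after training, under the same local-server assumption:

    # Download the diagnostic plots added in this commit (requires a trained model).
    import requests

    BASE = "http://localhost:7860"
    for name in ("confusion-matrix", "feature-log-prob"):
        r = requests.get(f"{BASE}/api/plot/{name}")
        r.raise_for_status()              # 400 if no model has been trained yet
        with open(f"{name}.png", "wb") as out:
            out.write(r.content)          # raw PNG bytes from the StreamingResponse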
requirements.txt CHANGED
@@ -1,8 +1,10 @@
- fastapi
- uvicorn
- python-multipart
- pandas
- scikit-learn
- numpy
- matplotlib
- gunicorn
+ fastapi>=0.104.0
+ uvicorn>=0.23.2
+ python-multipart>=0.0.6
+ pandas>=2.0.0
+ scikit-learn>=1.3.0
+ numpy>=1.24.0
+ matplotlib>=3.7.0
+ seaborn>=0.12.2
+ gunicorn>=21.2.0
+ pydantic>=2.0.0