AbdullahImran committed on
Commit
bfcfd26
·
verified ·
1 Parent(s): a235e07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -193
app.py CHANGED
@@ -31,148 +31,150 @@ def process_csv_file(file):
31
  gr.Warning(f"Error reading CSV file: {str(e)}")
32
  return None
33
 
34
- def classify_fn(file):
35
- """Bankruptcy classification from CSV file"""
36
  if file is None:
37
- return "Please upload a CSV file", None
38
 
39
  df = process_csv_file(file)
40
  if df is None:
41
- return "Error processing file", None
42
 
43
  try:
44
- # Use all rows in the CSV for prediction
45
- preds = xgb_clf.predict(df)
46
- probs = xgb_clf.predict_proba(df)
47
 
48
- # Create visualization
49
- fig, ax = plt.subplots(figsize=(10, 6), facecolor='#1f1f1f')
50
- ax.set_facecolor('#1f1f1f')
 
 
51
 
52
- if len(preds) == 1:
53
- # Single company prediction
54
- bars = ax.bar(['No Bankruptcy', 'Bankruptcy'], probs[0],
55
- color=['#4CAF50', '#F44336'], alpha=0.8)
56
- ax.set_ylim(0, 1)
57
- ax.set_title('Bankruptcy Probability', color='white', fontsize=14)
58
- ax.set_ylabel('Probability', color='white')
59
- result_text = f"Prediction: {'Bankruptcy Risk' if preds[0] == 1 else 'No Bankruptcy Risk'}\nConfidence: {max(probs[0]):.2%}"
60
- else:
61
- # Multiple companies
62
- bankruptcy_count = np.sum(preds)
63
- safe_count = len(preds) - bankruptcy_count
64
- bars = ax.bar(['Safe Companies', 'At Risk Companies'],
65
- [safe_count, bankruptcy_count],
66
- color=['#4CAF50', '#F44336'], alpha=0.8)
67
- ax.set_title(f'Bankruptcy Analysis for {len(preds)} Companies', color='white', fontsize=14)
68
- ax.set_ylabel('Number of Companies', color='white')
69
- result_text = f"Total Companies: {len(preds)}\nSafe: {safe_count}\nAt Risk: {bankruptcy_count}"
70
-
71
- ax.tick_params(colors='white')
72
- ax.spines['bottom'].set_color('white')
73
- ax.spines['left'].set_color('white')
74
- ax.spines['top'].set_visible(False)
75
- ax.spines['right'].set_visible(False)
76
-
77
- plt.tight_layout()
78
- return result_text, fig
79
 
80
- except Exception as e:
81
- return f"Error in prediction: {str(e)}", None
82
-
83
- def regress_fn(file):
84
- """Anomaly detection from CSV file"""
85
- if file is None:
86
- return "Please upload a CSV file", None
87
-
88
- df = process_csv_file(file)
89
- if df is None:
90
- return "Error processing file", None
91
-
92
- try:
93
- preds = xgb_reg.predict(df)
94
 
95
- # Create visualization
96
- fig, ax = plt.subplots(figsize=(10, 6), facecolor='#1f1f1f')
97
- ax.set_facecolor('#1f1f1f')
98
 
99
- sns.histplot(preds, bins=20, kde=True, ax=ax, color='#00BCD4', alpha=0.7)
100
- ax.set_title('Anomaly Score Distribution', color='white', fontsize=14)
101
- ax.set_xlabel('Anomaly Score', color='white')
102
- ax.set_ylabel('Frequency', color='white')
103
- ax.tick_params(colors='white')
104
- ax.spines['bottom'].set_color('white')
105
- ax.spines['left'].set_color('white')
106
- ax.spines['top'].set_visible(False)
107
- ax.spines['right'].set_visible(False)
 
 
 
 
 
 
 
108
 
 
 
 
 
 
109
  plt.tight_layout()
110
 
111
- # Summary statistics
112
- avg_score = np.mean(preds)
113
- high_risk_count = np.sum(preds > np.percentile(preds, 75))
114
- result_text = f"Average Anomaly Score: {avg_score:.3f}\nHigh Risk Companies: {high_risk_count}/{len(preds)}\nScore Range: {np.min(preds):.3f} - {np.max(preds):.3f}"
115
-
116
- return result_text, fig
117
-
118
- except Exception as e:
119
- return f"Error in prediction: {str(e)}", None
120
-
121
- def lstm_fn(file):
122
- """LSTM revenue forecasting from CSV file"""
123
- if file is None:
124
- return "Please upload a CSV file", None
125
-
126
- df = process_csv_file(file)
127
- if df is None:
128
- return "Error processing file", None
129
-
130
- try:
131
- # Expect CSV with revenue columns or a single row with 10 revenue values
132
- if df.shape[1] < 10:
133
- return "CSV must contain at least 10 revenue columns for quarterly data", None
134
 
135
- # Take first row and first 10 columns as revenue sequence
136
- vals = df.iloc[0, :10].values.astype(float).reshape(1, -1)
 
137
 
138
- # Scale and predict
139
- vals_s = scaler_X.transform(vals).reshape((1, vals.shape[1], 1))
140
- pred_s = lstm_model.predict(vals_s)
141
- pred = scaler_y.inverse_transform(pred_s)[0, 0]
142
-
143
- # Create visualization
144
- fig, ax = plt.subplots(figsize=(12, 6), facecolor='#1f1f1f')
145
- ax.set_facecolor('#1f1f1f')
146
-
147
- quarters = [f'Q{i+1}' for i in range(10)]
148
- ax.plot(quarters, vals.flatten(), marker='o', linewidth=2,
149
- markersize=8, color='#2196F3', label='Historical Revenue')
150
- ax.plot('Q11', pred, marker='X', markersize=15, color='#FF5722',
151
- label=f'Predicted Q11: ${pred:,.0f}')
152
 
153
- ax.set_xlabel('Quarter', color='white')
154
- ax.set_ylabel('Revenue ($)', color='white')
155
- ax.set_title('Revenue Forecast - Next Quarter Prediction', color='white', fontsize=14)
156
- ax.legend(facecolor='#2f2f2f', edgecolor='white', labelcolor='white')
157
- ax.tick_params(colors='white')
158
- ax.spines['bottom'].set_color('white')
159
- ax.spines['left'].set_color('white')
160
- ax.spines['top'].set_visible(False)
161
- ax.spines['right'].set_visible(False)
162
- ax.grid(True, alpha=0.3, color='white')
163
 
164
- plt.xticks(rotation=45)
165
- plt.tight_layout()
 
 
166
 
167
- # Calculate growth rate
168
- last_revenue = vals.flatten()[-1]
169
- growth_rate = ((pred - last_revenue) / last_revenue) * 100
170
- result_text = f"Predicted Q11 Revenue: ${pred:,.0f}\nGrowth from Q10: {growth_rate:+.1f}%"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- return result_text, fig
173
 
174
  except Exception as e:
175
- return f"Error in prediction: {str(e)}", None
 
176
 
177
  # Custom CSS for proper dark mode support
178
  custom_css = """
@@ -223,23 +225,6 @@ custom_css = """
223
  color: #ffffff !important;
224
  }
225
 
226
- /* Tab styling */
227
- .gr-tab-nav {
228
- background-color: #2d2d2d !important;
229
- border-bottom: 1px solid #404040 !important;
230
- }
231
-
232
- .gr-tab-nav button {
233
- background-color: transparent !important;
234
- color: #ffffff !important;
235
- border: none !important;
236
- }
237
-
238
- .gr-tab-nav button.selected {
239
- background-color: #0066cc !important;
240
- color: white !important;
241
- }
242
-
243
  /* Text and markdown */
244
  .gr-markdown {
245
  color: #ffffff !important;
@@ -259,70 +244,66 @@ custom_css = """
259
  with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="TriCast AI") as demo:
260
  gr.Markdown("""
261
  # πŸš€ TriCast AI
262
- ### Advanced Financial Intelligence Platform
263
- Upload your company's financial data as a CSV file to get comprehensive AI-powered insights across three key areas.
264
  """)
265
 
266
  gr.Markdown("""
267
- **πŸ“ CSV File Format Guidelines:**
268
- - **Bankruptcy & Anomaly Detection**: Include financial metrics as columns (revenue, debt, assets, etc.)
269
- - **Revenue Forecasting**: First 10 columns should contain quarterly revenue data
270
- - Each row represents one company's data
271
- """)
 
 
272
 
273
- with gr.Tab("🏦 Bankruptcy Risk Assessment"):
274
- gr.Markdown("**Upload CSV with company financial data to assess bankruptcy risk**")
275
- with gr.Row():
276
- with gr.Column():
277
- file1 = gr.File(label="Upload CSV File", file_types=[".csv"])
278
- classify_btn = gr.Button("πŸ” Analyze Bankruptcy Risk", variant="primary")
279
- with gr.Column():
280
- out1 = gr.Textbox(label="Analysis Results", lines=4)
281
- plt1 = gr.Plot(label="Risk Visualization")
282
- classify_btn.click(fn=classify_fn, inputs=file1, outputs=[out1, plt1])
283
 
284
- with gr.Tab("πŸ“Š Anomaly Detection"):
285
- gr.Markdown("**Upload CSV with company financial data to detect anomalies**")
286
- with gr.Row():
287
- with gr.Column():
288
- file2 = gr.File(label="Upload CSV File", file_types=[".csv"])
289
- regress_btn = gr.Button("πŸ”Ž Detect Anomalies", variant="primary")
290
- with gr.Column():
291
- out2 = gr.Textbox(label="Anomaly Analysis", lines=4)
292
- plt2 = gr.Plot(label="Score Distribution")
293
- regress_btn.click(fn=regress_fn, inputs=file2, outputs=[out2, plt2])
 
 
294
 
295
- with gr.Tab("πŸ“ˆ Revenue Forecasting"):
296
- gr.Markdown("**Upload CSV with quarterly revenue data (10 quarters) to forecast next quarter**")
297
- with gr.Row():
298
- with gr.Column():
299
- file3 = gr.File(label="Upload CSV File", file_types=[".csv"])
300
- forecast_btn = gr.Button("πŸ“Š Forecast Revenue", variant="primary")
301
- with gr.Column():
302
- out3 = gr.Textbox(label="Forecast Results", lines=4)
303
- plt3 = gr.Plot(label="Revenue Trend & Prediction")
304
- forecast_btn.click(fn=lstm_fn, inputs=file3, outputs=[out3, plt3])
305
 
306
- with gr.Tab("πŸ“‹ Sample Data Format"):
307
- gr.Markdown("""
308
- ### Sample CSV Formats:
 
 
 
 
 
 
 
 
309
 
310
- **For Bankruptcy & Anomaly Detection:**
311
- ```
312
- company_name,total_assets,total_liabilities,revenue,debt_ratio,current_ratio
313
- Company A,1000000,500000,800000,0.5,2.1
314
- Company B,2000000,1800000,600000,0.9,0.8
315
- ```
 
 
316
 
317
- **For Revenue Forecasting:**
318
- ```
319
- q1_revenue,q2_revenue,q3_revenue,q4_revenue,q5_revenue,q6_revenue,q7_revenue,q8_revenue,q9_revenue,q10_revenue
320
- 100000,120000,110000,130000,125000,140000,135000,150000,145000,160000
321
- ```
322
- """)
323
-
324
- gr.Markdown("---")
325
- gr.Markdown("*TriCast AI - Powered by Advanced Machine Learning | Industry, Innovation and Infrastructure*")
326
 
327
  if __name__ == "__main__":
328
- demo.launch()
 
31
  gr.Warning(f"Error reading CSV file: {str(e)}")
32
  return None
33
 
34
+ def run_all_models(file):
35
+ """Run all three models on the uploaded CSV file"""
36
  if file is None:
37
+ return "Please upload a CSV file", None, None, None, None, None
38
 
39
  df = process_csv_file(file)
40
  if df is None:
41
+ return "Error processing file", None, None, None, None, None
42
 
43
  try:
44
+ # Prepare data for models (assuming same feature set as training)
45
+ model_features = df.copy()
 
46
 
47
+ # Remove non-feature columns if they exist
48
+ cols_to_remove = ['Id', 'anomaly_score', 'risk_flag']
49
+ for col in cols_to_remove:
50
+ if col in model_features.columns:
51
+ model_features = model_features.drop(col, axis=1)
52
 
53
+ # Handle missing values
54
+ model_features = model_features.fillna(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # 1. BANKRUPTCY CLASSIFICATION
57
+ bankruptcy_preds = xgb_clf.predict(model_features)
58
+ bankruptcy_probs = xgb_clf.predict_proba(model_features)
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # Create bankruptcy visualization
61
+ fig1, ax1 = plt.subplots(figsize=(10, 6), facecolor='#1f1f1f')
62
+ ax1.set_facecolor('#1f1f1f')
63
 
64
+ if len(bankruptcy_preds) == 1:
65
+ bars = ax1.bar(['No Bankruptcy', 'Bankruptcy'], bankruptcy_probs[0],
66
+ color=['#4CAF50', '#F44336'], alpha=0.8)
67
+ ax1.set_ylim(0, 1)
68
+ ax1.set_title('Bankruptcy Risk Probability', color='white', fontsize=14)
69
+ ax1.set_ylabel('Probability', color='white')
70
+ bankruptcy_result = f"Prediction: {'High Bankruptcy Risk' if bankruptcy_preds[0] == 1 else 'Low Bankruptcy Risk'}\nConfidence: {max(bankruptcy_probs[0]):.2%}"
71
+ else:
72
+ bankruptcy_count = np.sum(bankruptcy_preds)
73
+ safe_count = len(bankruptcy_preds) - bankruptcy_count
74
+ bars = ax1.bar(['Safe Companies', 'At Risk Companies'],
75
+ [safe_count, bankruptcy_count],
76
+ color=['#4CAF50', '#F44336'], alpha=0.8)
77
+ ax1.set_title(f'Bankruptcy Analysis for {len(bankruptcy_preds)} Companies', color='white', fontsize=14)
78
+ ax1.set_ylabel('Number of Companies', color='white')
79
+ bankruptcy_result = f"Total Companies: {len(bankruptcy_preds)}\nSafe: {safe_count}\nAt Risk: {bankruptcy_count}"
80
 
81
+ ax1.tick_params(colors='white')
82
+ ax1.spines['bottom'].set_color('white')
83
+ ax1.spines['left'].set_color('white')
84
+ ax1.spines['top'].set_visible(False)
85
+ ax1.spines['right'].set_visible(False)
86
  plt.tight_layout()
87
 
88
+ # 2. ANOMALY DETECTION
89
+ anomaly_preds = xgb_reg.predict(model_features)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ # Create anomaly visualization
92
+ fig2, ax2 = plt.subplots(figsize=(10, 6), facecolor='#1f1f1f')
93
+ ax2.set_facecolor('#1f1f1f')
94
 
95
+ sns.histplot(anomaly_preds, bins=20, kde=True, ax=ax2, color='#00BCD4', alpha=0.7)
96
+ ax2.set_title('Anomaly Score Distribution', color='white', fontsize=14)
97
+ ax2.set_xlabel('Anomaly Score', color='white')
98
+ ax2.set_ylabel('Frequency', color='white')
99
+ ax2.tick_params(colors='white')
100
+ ax2.spines['bottom'].set_color('white')
101
+ ax2.spines['left'].set_color('white')
102
+ ax2.spines['top'].set_visible(False)
103
+ ax2.spines['right'].set_visible(False)
104
+ plt.tight_layout()
 
 
 
 
105
 
106
+ avg_score = np.mean(anomaly_preds)
107
+ high_risk_count = np.sum(anomaly_preds > np.percentile(anomaly_preds, 75))
108
+ anomaly_result = f"Average Anomaly Score: {avg_score:.3f}\nHigh Risk Companies: {high_risk_count}/{len(anomaly_preds)}\nScore Range: {np.min(anomaly_preds):.3f} - {np.max(anomaly_preds):.3f}"
 
 
 
 
 
 
 
109
 
110
+ # 3. LSTM REVENUE FORECASTING
111
+ # Extract revenue data from Q1_REVENUES to Q10_REVENUES
112
+ revenue_cols = [f'Q{i}_REVENUES' for i in range(1, 11)]
113
+ missing_cols = [col for col in revenue_cols if col not in df.columns]
114
 
115
+ if missing_cols:
116
+ lstm_result = f"Missing revenue columns for LSTM: {missing_cols}"
117
+ fig3 = plt.figure(figsize=(10, 6), facecolor='#1f1f1f')
118
+ ax3 = fig3.add_subplot(111, facecolor='#1f1f1f')
119
+ ax3.text(0.5, 0.5, 'Revenue columns not found in dataset',
120
+ ha='center', va='center', color='white', fontsize=14)
121
+ ax3.set_xlim(0, 1)
122
+ ax3.set_ylim(0, 1)
123
+ ax3.axis('off')
124
+ else:
125
+ # Use first company's revenue data for LSTM prediction
126
+ revenue_data = df[revenue_cols].iloc[0].values.astype(float)
127
+
128
+ # Handle missing values in revenue data
129
+ if np.any(np.isnan(revenue_data)) or np.any(revenue_data == 0):
130
+ # Replace NaN and zeros with interpolated values
131
+ mask = ~np.isnan(revenue_data) & (revenue_data != 0)
132
+ if np.sum(mask) > 1:
133
+ revenue_data[~mask] = np.interp(np.where(~mask)[0], np.where(mask)[0], revenue_data[mask])
134
+ else:
135
+ revenue_data = np.full_like(revenue_data, np.mean(revenue_data[mask]) if np.sum(mask) > 0 else 1000000)
136
+
137
+ revenue_data = revenue_data.reshape(1, -1)
138
+
139
+ # Scale and predict
140
+ revenue_scaled = scaler_X.transform(revenue_data).reshape((1, revenue_data.shape[1], 1))
141
+ pred_scaled = lstm_model.predict(revenue_scaled)
142
+ predicted_revenue = scaler_y.inverse_transform(pred_scaled)[0, 0]
143
+
144
+ # Create LSTM visualization
145
+ fig3, ax3 = plt.subplots(figsize=(12, 6), facecolor='#1f1f1f')
146
+ ax3.set_facecolor('#1f1f1f')
147
+
148
+ quarters = [f'Q{i}' for i in range(1, 11)]
149
+ ax3.plot(quarters, revenue_data.flatten(), marker='o', linewidth=2,
150
+ markersize=8, color='#2196F3', label='Historical Revenue')
151
+ ax3.plot('Q11', predicted_revenue, marker='X', markersize=15, color='#FF5722',
152
+ label=f'Predicted Q11: ${predicted_revenue:,.0f}')
153
+
154
+ ax3.set_xlabel('Quarter', color='white')
155
+ ax3.set_ylabel('Revenue ($)', color='white')
156
+ ax3.set_title('Revenue Forecast - Next Quarter Prediction', color='white', fontsize=14)
157
+ ax3.legend(facecolor='#2f2f2f', edgecolor='white', labelcolor='white')
158
+ ax3.tick_params(colors='white')
159
+ ax3.spines['bottom'].set_color('white')
160
+ ax3.spines['left'].set_color('white')
161
+ ax3.spines['top'].set_visible(False)
162
+ ax3.spines['right'].set_visible(False)
163
+ ax3.grid(True, alpha=0.3, color='white')
164
+
165
+ plt.xticks(rotation=45)
166
+ plt.tight_layout()
167
+
168
+ # Calculate growth rate
169
+ last_revenue = revenue_data.flatten()[-1]
170
+ growth_rate = ((predicted_revenue - last_revenue) / last_revenue) * 100
171
+ lstm_result = f"Predicted Q11 Revenue: ${predicted_revenue:,.0f}\nGrowth from Q10: {growth_rate:+.1f}%\nLast Quarter (Q10): ${last_revenue:,.0f}"
172
 
173
+ return bankruptcy_result, fig1, anomaly_result, fig2, lstm_result, fig3
174
 
175
  except Exception as e:
176
+ error_msg = f"Error in prediction: {str(e)}"
177
+ return error_msg, None, error_msg, None, error_msg, None
178
 
179
  # Custom CSS for proper dark mode support
180
  custom_css = """
 
225
  color: #ffffff !important;
226
  }
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  /* Text and markdown */
229
  .gr-markdown {
230
  color: #ffffff !important;
 
244
  with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="TriCast AI") as demo:
245
  gr.Markdown("""
246
  # πŸš€ TriCast AI
247
+ ### Comprehensive Financial Intelligence Platform
248
+ Upload your company's financial data CSV file to get AI-powered insights across three key areas **simultaneously**.
249
  """)
250
 
251
  gr.Markdown("""
252
+ **πŸ“ Expected CSV Format:**
253
+ Your CSV should contain financial metrics including:
254
+ - Basic info: `industry`, `sector`, `fullTimeEmployees`
255
+ - Risk metrics: `auditRisk`, `boardRisk`, `compensationRisk`, etc.
256
+ - Financial ratios: `trailingPE`, `forwardPE`, `totalDebt`, `totalRevenue`, etc.
257
+ - Quarterly data: `Q1_REVENUES`, `Q2_REVENUES`, ..., `Q10_REVENUES` (for LSTM forecasting)
258
+ - Quarterly financials: `Q*_TOTAL_ASSETS`, `Q*_TOTAL_LIABILITIES`, etc.
259
 
260
+ πŸ“Š **One Upload = Three AI Models Running Simultaneously!**
261
+ """)
 
 
 
 
 
 
 
 
262
 
263
+ with gr.Row():
264
+ with gr.Column(scale=1):
265
+ file_input = gr.File(
266
+ label="πŸ“ Upload Company Financial Data (CSV)",
267
+ file_types=[".csv"],
268
+ elem_id="file_upload"
269
+ )
270
+ analyze_btn = gr.Button(
271
+ "πŸš€ Run TriCast AI Analysis",
272
+ variant="primary",
273
+ size="lg"
274
+ )
275
 
276
+ gr.Markdown("---")
 
 
 
 
 
 
 
 
 
277
 
278
+
279
+ # Results section with three columns
280
+ with gr.Row():
281
+ with gr.Column():
282
+ gr.Markdown("### 🏦 Bankruptcy Risk Assessment")
283
+ bankruptcy_output = gr.Textbox(
284
+ label="Risk Analysis",
285
+ lines=4,
286
+ placeholder="Results will appear here..."
287
+ )
288
+ bankruptcy_plot = gr.Plot(label="Risk Visualization")
289
 
290
+ with gr.Column():
291
+ gr.Markdown("### πŸ“Š Anomaly Detection")
292
+ anomaly_output = gr.Textbox(
293
+ label="Anomaly Analysis",
294
+ lines=4,
295
+ placeholder="Results will appear here..."
296
+ )
297
+ anomaly_plot = gr.Plot(label="Score Distribution")
298
 
299
+ with gr.Column():
300
+ gr.Markdown("### πŸ“ˆ Revenue Forecasting")
301
+ lstm_output = gr.Textbox(
302
+ label="Forecast Summary",
303
+ lines=4,
304
+ placeholder="Results will appear here..."
305
+ )
306
+ lstm_plot = gr.Plot(label="Revenue Forecast")
 
307
 
308
  if __name__ == "__main__":
309
+ demo.launch()