AbdullahImran commited on
Commit
a235e07
Β·
verified Β·
1 Parent(s): 5b00a4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +308 -77
app.py CHANGED
@@ -7,91 +7,322 @@ from tensorflow.keras.models import load_model
7
  from sklearn.preprocessing import MinMaxScaler
8
  import matplotlib.pyplot as plt
9
  import seaborn as sns
 
10
 
11
  # Load models & scalers
12
  xgb_clf = xgb.XGBClassifier()
13
  xgb_clf.load_model("xgb_model.json")
14
-
15
  xgb_reg = joblib.load("xgb_pipeline_model.pkl")
16
-
17
  scaler_X = joblib.load("scaler_X.pkl")
18
  scaler_y = joblib.load("scaler_y.pkl")
19
-
20
  lstm_model = load_model("lstm_revenue_model.keras")
21
 
22
- # Prediction + Plot functions
23
- def classify_fn(df: pd.DataFrame):
24
- preds = xgb_clf.predict(df)
25
- probs = xgb_clf.predict_proba(df)
26
- fig, ax = plt.subplots()
27
- ax.bar(['No Bankruptcy', 'Bankruptcy'], probs[0], color=['#4CAF50', '#F44336'])
28
- ax.set_ylim(0, 1)
29
- ax.set_title('Bankruptcy Probability')
30
- ax.set_ylabel('Probability')
31
- plt.tight_layout()
32
- return {"Predicted Label": int(preds[0])}, fig
33
-
34
-
35
- def regress_fn(df: pd.DataFrame):
36
- preds = xgb_reg.predict(df)
37
- fig, ax = plt.subplots()
38
- sns.histplot(preds, bins=20, kde=True, ax=ax)
39
- ax.set_title('Anomaly Score Distribution')
40
- ax.set_xlabel('Predicted Anomaly Score')
41
- plt.tight_layout()
42
- return preds.tolist(), fig
43
-
44
-
45
- def lstm_fn(seq_str: str):
46
- vals = np.array(list(map(float, seq_str.split(',')))).reshape(1, -1)
47
- vals_s = scaler_X.transform(vals).reshape((1, vals.shape[1], 1))
48
- pred_s = lstm_model.predict(vals_s)
49
- pred = scaler_y.inverse_transform(pred_s)[0, 0]
50
- fig, ax = plt.subplots()
51
- ax.plot(range(10), vals.flatten(), marker='o', label='Input Revenue')
52
- ax.plot(10, pred, marker='X', markersize=10, color='red', label='Predicted Q10')
53
- ax.set_xlabel('Quarter Index (0-10)')
54
- ax.set_ylabel('Revenue')
55
- ax.set_title('Revenue Forecast')
56
- ax.legend()
57
- plt.tight_layout()
58
- return float(pred), fig
59
-
60
- # Build UI
61
- grid_css = """
62
- body {background-color: #f7f7f7;}
63
- .gradio-container {max-width: 800px; margin: auto; padding: 20px;}
64
- h1, h2 {color: #333;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  """
66
 
67
- demo = gr.Blocks(css=grid_css)
68
- with demo:
69
- gr.Markdown("# πŸš€ FinSight 360β„’ Dashboard")
70
- gr.Markdown("Comprehensive financial AI:\n- Bankruptcy Classification\n- Anomaly Scoring\n- Revenue Forecasting")
71
-
72
- with gr.Tab("🏦 Bankruptcy Classifier"):
73
- gr.Markdown("**Upload company features** (as DataFrame) to predict bankruptcy:")
74
- inp1 = gr.Dataframe(type="pandas", label="Features DataFrame")
75
- classify_btn = gr.Button("Run Classification")
76
- out1 = gr.Label(label="Predicted Label")
77
- plt1 = gr.Plot()
78
- classify_btn.click(fn=classify_fn, inputs=inp1, outputs=[out1, plt1])
79
-
80
- with gr.Tab("πŸ“ˆ Anomaly Regression"):
81
- gr.Markdown("**Upload company features** (as DataFrame) to predict anomaly score:")
82
- inp2 = gr.Dataframe(type="pandas", label="Features DataFrame")
83
- regress_btn = gr.Button("Run Regression")
84
- out2 = gr.Textbox(label="Predicted Scores List")
85
- plt2 = gr.Plot()
86
- regress_btn.click(fn=regress_fn, inputs=inp2, outputs=[out2, plt2])
87
-
88
- with gr.Tab("πŸ“Š LSTM Revenue Forecast"):
89
- gr.Markdown("**Enter last 10 quarterly revenues** (comma-separated) to forecast Q10 revenue:")
90
- inp3 = gr.Textbox(placeholder="e.g. 1000,1200,1100,...", label="Q0–Q9 Revenues")
91
- out3 = gr.Number(label="Predicted Q10 Revenue")
92
- plt3 = gr.Plot()
93
- inp3.submit(fn=lstm_fn, inputs=inp3, outputs=[out3, plt3])
94
-
95
- gr.Markdown("---\n*Industry, Innovation and Infrastructure*")
96
-
97
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from sklearn.preprocessing import MinMaxScaler
8
  import matplotlib.pyplot as plt
9
  import seaborn as sns
10
+ import io
11
 
12
  # Load models & scalers
13
  xgb_clf = xgb.XGBClassifier()
14
  xgb_clf.load_model("xgb_model.json")
 
15
  xgb_reg = joblib.load("xgb_pipeline_model.pkl")
 
16
  scaler_X = joblib.load("scaler_X.pkl")
17
  scaler_y = joblib.load("scaler_y.pkl")
 
18
  lstm_model = load_model("lstm_revenue_model.keras")
19
 
20
+ # Set matplotlib style for dark theme compatibility
21
+ plt.style.use('dark_background')
22
+
23
+ def process_csv_file(file):
24
+ """Process uploaded CSV file and return DataFrame"""
25
+ if file is None:
26
+ return None
27
+ try:
28
+ df = pd.read_csv(file.name)
29
+ return df
30
+ except Exception as e:
31
+ gr.Warning(f"Error reading CSV file: {str(e)}")
32
+ return None
33
+
34
+ def classify_fn(file):
35
+ """Bankruptcy classification from CSV file"""
36
+ if file is None:
37
+ return "Please upload a CSV file", None
38
+
39
+ df = process_csv_file(file)
40
+ if df is None:
41
+ return "Error processing file", None
42
+
43
+ try:
44
+ # Use all rows in the CSV for prediction
45
+ preds = xgb_clf.predict(df)
46
+ probs = xgb_clf.predict_proba(df)
47
+
48
+ # Create visualization
49
+ fig, ax = plt.subplots(figsize=(10, 6), facecolor='#1f1f1f')
50
+ ax.set_facecolor('#1f1f1f')
51
+
52
+ if len(preds) == 1:
53
+ # Single company prediction
54
+ bars = ax.bar(['No Bankruptcy', 'Bankruptcy'], probs[0],
55
+ color=['#4CAF50', '#F44336'], alpha=0.8)
56
+ ax.set_ylim(0, 1)
57
+ ax.set_title('Bankruptcy Probability', color='white', fontsize=14)
58
+ ax.set_ylabel('Probability', color='white')
59
+ result_text = f"Prediction: {'Bankruptcy Risk' if preds[0] == 1 else 'No Bankruptcy Risk'}\nConfidence: {max(probs[0]):.2%}"
60
+ else:
61
+ # Multiple companies
62
+ bankruptcy_count = np.sum(preds)
63
+ safe_count = len(preds) - bankruptcy_count
64
+ bars = ax.bar(['Safe Companies', 'At Risk Companies'],
65
+ [safe_count, bankruptcy_count],
66
+ color=['#4CAF50', '#F44336'], alpha=0.8)
67
+ ax.set_title(f'Bankruptcy Analysis for {len(preds)} Companies', color='white', fontsize=14)
68
+ ax.set_ylabel('Number of Companies', color='white')
69
+ result_text = f"Total Companies: {len(preds)}\nSafe: {safe_count}\nAt Risk: {bankruptcy_count}"
70
+
71
+ ax.tick_params(colors='white')
72
+ ax.spines['bottom'].set_color('white')
73
+ ax.spines['left'].set_color('white')
74
+ ax.spines['top'].set_visible(False)
75
+ ax.spines['right'].set_visible(False)
76
+
77
+ plt.tight_layout()
78
+ return result_text, fig
79
+
80
+ except Exception as e:
81
+ return f"Error in prediction: {str(e)}", None
82
+
83
+ def regress_fn(file):
84
+ """Anomaly detection from CSV file"""
85
+ if file is None:
86
+ return "Please upload a CSV file", None
87
+
88
+ df = process_csv_file(file)
89
+ if df is None:
90
+ return "Error processing file", None
91
+
92
+ try:
93
+ preds = xgb_reg.predict(df)
94
+
95
+ # Create visualization
96
+ fig, ax = plt.subplots(figsize=(10, 6), facecolor='#1f1f1f')
97
+ ax.set_facecolor('#1f1f1f')
98
+
99
+ sns.histplot(preds, bins=20, kde=True, ax=ax, color='#00BCD4', alpha=0.7)
100
+ ax.set_title('Anomaly Score Distribution', color='white', fontsize=14)
101
+ ax.set_xlabel('Anomaly Score', color='white')
102
+ ax.set_ylabel('Frequency', color='white')
103
+ ax.tick_params(colors='white')
104
+ ax.spines['bottom'].set_color('white')
105
+ ax.spines['left'].set_color('white')
106
+ ax.spines['top'].set_visible(False)
107
+ ax.spines['right'].set_visible(False)
108
+
109
+ plt.tight_layout()
110
+
111
+ # Summary statistics
112
+ avg_score = np.mean(preds)
113
+ high_risk_count = np.sum(preds > np.percentile(preds, 75))
114
+ result_text = f"Average Anomaly Score: {avg_score:.3f}\nHigh Risk Companies: {high_risk_count}/{len(preds)}\nScore Range: {np.min(preds):.3f} - {np.max(preds):.3f}"
115
+
116
+ return result_text, fig
117
+
118
+ except Exception as e:
119
+ return f"Error in prediction: {str(e)}", None
120
+
121
+ def lstm_fn(file):
122
+ """LSTM revenue forecasting from CSV file"""
123
+ if file is None:
124
+ return "Please upload a CSV file", None
125
+
126
+ df = process_csv_file(file)
127
+ if df is None:
128
+ return "Error processing file", None
129
+
130
+ try:
131
+ # Expect CSV with revenue columns or a single row with 10 revenue values
132
+ if df.shape[1] < 10:
133
+ return "CSV must contain at least 10 revenue columns for quarterly data", None
134
+
135
+ # Take first row and first 10 columns as revenue sequence
136
+ vals = df.iloc[0, :10].values.astype(float).reshape(1, -1)
137
+
138
+ # Scale and predict
139
+ vals_s = scaler_X.transform(vals).reshape((1, vals.shape[1], 1))
140
+ pred_s = lstm_model.predict(vals_s)
141
+ pred = scaler_y.inverse_transform(pred_s)[0, 0]
142
+
143
+ # Create visualization
144
+ fig, ax = plt.subplots(figsize=(12, 6), facecolor='#1f1f1f')
145
+ ax.set_facecolor('#1f1f1f')
146
+
147
+ quarters = [f'Q{i+1}' for i in range(10)]
148
+ ax.plot(quarters, vals.flatten(), marker='o', linewidth=2,
149
+ markersize=8, color='#2196F3', label='Historical Revenue')
150
+ ax.plot('Q11', pred, marker='X', markersize=15, color='#FF5722',
151
+ label=f'Predicted Q11: ${pred:,.0f}')
152
+
153
+ ax.set_xlabel('Quarter', color='white')
154
+ ax.set_ylabel('Revenue ($)', color='white')
155
+ ax.set_title('Revenue Forecast - Next Quarter Prediction', color='white', fontsize=14)
156
+ ax.legend(facecolor='#2f2f2f', edgecolor='white', labelcolor='white')
157
+ ax.tick_params(colors='white')
158
+ ax.spines['bottom'].set_color('white')
159
+ ax.spines['left'].set_color('white')
160
+ ax.spines['top'].set_visible(False)
161
+ ax.spines['right'].set_visible(False)
162
+ ax.grid(True, alpha=0.3, color='white')
163
+
164
+ plt.xticks(rotation=45)
165
+ plt.tight_layout()
166
+
167
+ # Calculate growth rate
168
+ last_revenue = vals.flatten()[-1]
169
+ growth_rate = ((pred - last_revenue) / last_revenue) * 100
170
+ result_text = f"Predicted Q11 Revenue: ${pred:,.0f}\nGrowth from Q10: {growth_rate:+.1f}%"
171
+
172
+ return result_text, fig
173
+
174
+ except Exception as e:
175
+ return f"Error in prediction: {str(e)}", None
176
+
177
+ # Custom CSS for proper dark mode support
178
+ custom_css = """
179
+ /* Dark theme for the entire interface */
180
+ .gradio-container {
181
+ background-color: #1a1a1a !important;
182
+ color: #ffffff !important;
183
+ }
184
+
185
+ .gr-box {
186
+ background-color: #2d2d2d !important;
187
+ border: 1px solid #404040 !important;
188
+ }
189
+
190
+ .gr-form {
191
+ background-color: #2d2d2d !important;
192
+ }
193
+
194
+ .gr-panel {
195
+ background-color: #2d2d2d !important;
196
+ border: 1px solid #404040 !important;
197
+ }
198
+
199
+ .gr-button {
200
+ background-color: #0066cc !important;
201
+ color: white !important;
202
+ border: none !important;
203
+ }
204
+
205
+ .gr-button:hover {
206
+ background-color: #0052a3 !important;
207
+ }
208
+
209
+ .gr-input, .gr-textbox {
210
+ background-color: #2d2d2d !important;
211
+ border: 1px solid #404040 !important;
212
+ color: #ffffff !important;
213
+ }
214
+
215
+ .gr-upload {
216
+ background-color: #2d2d2d !important;
217
+ border: 2px dashed #404040 !important;
218
+ color: #ffffff !important;
219
+ }
220
+
221
+ .gr-file {
222
+ background-color: #2d2d2d !important;
223
+ color: #ffffff !important;
224
+ }
225
+
226
+ /* Tab styling */
227
+ .gr-tab-nav {
228
+ background-color: #2d2d2d !important;
229
+ border-bottom: 1px solid #404040 !important;
230
+ }
231
+
232
+ .gr-tab-nav button {
233
+ background-color: transparent !important;
234
+ color: #ffffff !important;
235
+ border: none !important;
236
+ }
237
+
238
+ .gr-tab-nav button.selected {
239
+ background-color: #0066cc !important;
240
+ color: white !important;
241
+ }
242
+
243
+ /* Text and markdown */
244
+ .gr-markdown {
245
+ color: #ffffff !important;
246
+ }
247
+
248
+ .gr-markdown h1, .gr-markdown h2, .gr-markdown h3 {
249
+ color: #ffffff !important;
250
+ }
251
+
252
+ /* Ensure plot backgrounds work with dark theme */
253
+ .gr-plot {
254
+ background-color: #1f1f1f !important;
255
+ }
256
  """
257
 
258
+ # Create the Gradio interface
259
+ with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="TriCast AI") as demo:
260
+ gr.Markdown("""
261
+ # πŸš€ TriCast AI
262
+ ### Advanced Financial Intelligence Platform
263
+ Upload your company's financial data as a CSV file to get comprehensive AI-powered insights across three key areas.
264
+ """)
265
+
266
+ gr.Markdown("""
267
+ **πŸ“ CSV File Format Guidelines:**
268
+ - **Bankruptcy & Anomaly Detection**: Include financial metrics as columns (revenue, debt, assets, etc.)
269
+ - **Revenue Forecasting**: First 10 columns should contain quarterly revenue data
270
+ - Each row represents one company's data
271
+ """)
272
+
273
+ with gr.Tab("🏦 Bankruptcy Risk Assessment"):
274
+ gr.Markdown("**Upload CSV with company financial data to assess bankruptcy risk**")
275
+ with gr.Row():
276
+ with gr.Column():
277
+ file1 = gr.File(label="Upload CSV File", file_types=[".csv"])
278
+ classify_btn = gr.Button("πŸ” Analyze Bankruptcy Risk", variant="primary")
279
+ with gr.Column():
280
+ out1 = gr.Textbox(label="Analysis Results", lines=4)
281
+ plt1 = gr.Plot(label="Risk Visualization")
282
+ classify_btn.click(fn=classify_fn, inputs=file1, outputs=[out1, plt1])
283
+
284
+ with gr.Tab("πŸ“Š Anomaly Detection"):
285
+ gr.Markdown("**Upload CSV with company financial data to detect anomalies**")
286
+ with gr.Row():
287
+ with gr.Column():
288
+ file2 = gr.File(label="Upload CSV File", file_types=[".csv"])
289
+ regress_btn = gr.Button("πŸ”Ž Detect Anomalies", variant="primary")
290
+ with gr.Column():
291
+ out2 = gr.Textbox(label="Anomaly Analysis", lines=4)
292
+ plt2 = gr.Plot(label="Score Distribution")
293
+ regress_btn.click(fn=regress_fn, inputs=file2, outputs=[out2, plt2])
294
+
295
+ with gr.Tab("πŸ“ˆ Revenue Forecasting"):
296
+ gr.Markdown("**Upload CSV with quarterly revenue data (10 quarters) to forecast next quarter**")
297
+ with gr.Row():
298
+ with gr.Column():
299
+ file3 = gr.File(label="Upload CSV File", file_types=[".csv"])
300
+ forecast_btn = gr.Button("πŸ“Š Forecast Revenue", variant="primary")
301
+ with gr.Column():
302
+ out3 = gr.Textbox(label="Forecast Results", lines=4)
303
+ plt3 = gr.Plot(label="Revenue Trend & Prediction")
304
+ forecast_btn.click(fn=lstm_fn, inputs=file3, outputs=[out3, plt3])
305
+
306
+ with gr.Tab("πŸ“‹ Sample Data Format"):
307
+ gr.Markdown("""
308
+ ### Sample CSV Formats:
309
+
310
+ **For Bankruptcy & Anomaly Detection:**
311
+ ```
312
+ company_name,total_assets,total_liabilities,revenue,debt_ratio,current_ratio
313
+ Company A,1000000,500000,800000,0.5,2.1
314
+ Company B,2000000,1800000,600000,0.9,0.8
315
+ ```
316
+
317
+ **For Revenue Forecasting:**
318
+ ```
319
+ q1_revenue,q2_revenue,q3_revenue,q4_revenue,q5_revenue,q6_revenue,q7_revenue,q8_revenue,q9_revenue,q10_revenue
320
+ 100000,120000,110000,130000,125000,140000,135000,150000,145000,160000
321
+ ```
322
+ """)
323
+
324
+ gr.Markdown("---")
325
+ gr.Markdown("*TriCast AI - Powered by Advanced Machine Learning | Industry, Innovation and Infrastructure*")
326
+
327
+ if __name__ == "__main__":
328
+ demo.launch()