mgbam committed on
Commit
6e8a7d4
·
1 Parent(s): bdbd063

Add application file

Browse files
Files changed (1) hide show
  1. app.py +483 -0
app.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import streamlit as st
4
+ import numpy as np
5
+ import pandas as pd
6
+ from smolagents import CodeAgent, tool
7
+ from typing import Union, List, Dict, Optional
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ import os
13
+ from groq import Groq
14
+ from dataclasses import dataclass
15
+ import tempfile
16
+ import base64
17
+ import io
18
+ from sklearn.model_selection import train_test_split
19
+ from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
20
+ import joblib
21
+ import pdfkit # Ensure wkhtmltopdf is available in the environment
22
+ import uuid # For generating unique report IDs
23
+
24
+ # ------------------------------
25
+ # Language Model Interface
26
+ # ------------------------------
27
class GroqLLM:
    """Callable wrapper around the Groq chat-completions API.

    The agent framework invokes the model object like a function, so the
    class implements ``__call__`` and returns the generated text.
    """

    def __init__(self, model_name="llama-3.1-8B-Instant"):
        """Create the Groq client.

        Args:
            model_name: Identifier of the Groq-hosted model to query.
        """
        # The API key is taken from the environment, never from source code.
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(
        self,
        prompt: Union[str, dict, List[Dict]],
        stop_sequences: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """Send ``prompt`` to the model and return its reply as text.

        Args:
            prompt: A plain string, or a dict / list of message dicts. Any
                non-string prompt is flattened with ``str()`` and sent as a
                single user message.
            stop_sequences: Optional strings at which generation should stop.
                NOTE(review): smolagents appears to pass this keyword on every
                agent step; confirm against the installed version.
            **kwargs: Extra keywords supplied by the caller. They are accepted
                so the call does not fail, and are otherwise ignored.

        Returns:
            The generated text, ``"Error: No response generated"`` when the
            API returns no choices, or an ``"Error generating response: ..."``
            message when the request raises.
        """
        request = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": str(prompt)}],
            "temperature": 0.7,
            "max_tokens": 1500,  # Room for detailed responses
            "stream": False,
        }
        if stop_sequences:
            # NOTE(review): the Groq API is believed to accept at most four
            # stop sequences; confirm the limit in the API reference.
            request["stop"] = list(stop_sequences)[:4]

        try:
            completion = self.client.chat.completions.create(**request)
            if not completion.choices:
                return "Error: No response generated"
            return completion.choices[0].message.content
        except Exception as e:
            # Best effort by design: the failure is reported as the reply text
            # instead of propagating to the caller.
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return error_msg
60
+
61
+ # ------------------------------
62
+ # Data Analysis Agent
63
+ # ------------------------------
64
class DataAnalysisAgent(CodeAgent):
    """CodeAgent that carries a dataset and describes it in every task prompt."""

    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
        """Initialise the agent.

        Args:
            dataset: DataFrame the agent should reason about.
            *args: Positional arguments forwarded to ``CodeAgent``.
            **kwargs: Keyword arguments forwarded to ``CodeAgent``.
        """
        super().__init__(*args, **kwargs)
        self._dataset = dataset
        self.models = {}  # Trained models, keyed by model ID

    @property
    def dataset(self) -> pd.DataFrame:
        """The DataFrame supplied at construction time."""
        return self._dataset

    def run(self, prompt: str, *args, **kwargs) -> str:
        """Run ``prompt`` with a description of the dataset prepended.

        Args:
            prompt: The task to perform on the dataset.
            *args: Extra positional arguments forwarded to ``CodeAgent.run``.
            **kwargs: Extra keyword arguments forwarded to ``CodeAgent.run``.

        Returns:
            Whatever ``CodeAgent.run`` returns for the enriched prompt.
        """
        frame = self.dataset
        # Column labels are not guaranteed to be strings, so convert them
        # before joining.
        column_names = ', '.join(map(str, frame.columns))
        dataset_info = (
            f"Dataset Shape: {frame.shape}\n"
            f"Columns: {column_names}\n"
            f"Data Types: {frame.dtypes.to_dict()}"
        )
        enhanced_prompt = (
            "Analyze the following dataset:\n"
            f"{dataset_info}\n\n"
            f"Task: {prompt}\n\n"
            "Use the provided tools to analyze this specific dataset and return detailed results.\n"
            # The tools fall back to the loaded dataset when `data` is None.
            "The dataset is already loaded: pass data=None to a tool to have it use this dataset."
        )
        return super().run(enhanced_prompt, *args, **kwargs)
92
+
93
+ # ------------------------------
94
+ # Tool Definitions
95
+ # ------------------------------
96
+
97
@tool
def analyze_basic_stats(data: Optional[pd.DataFrame] = None) -> str:
    """Calculate and visualize basic statistical measures for numerical columns.

    Args:
        data: DataFrame to analyze. When omitted or None, the dataset stored
            in the Streamlit session under the ``data`` key is used.

    Returns:
        Markdown containing a statistics table (mean, median, std, skew and
        missing count per numerical column) followed by a bar chart embedded
        as a base64 PNG, or a short message when nothing can be analyzed.
    """
    if data is None:
        # The uploaded dataset is kept in the Streamlit session by the app.
        data = st.session_state.get('data')
    if data is None:
        return "Error: No dataset is available for analysis."

    numeric = data.select_dtypes(include=[np.number])
    if numeric.shape[1] == 0:
        # Nothing to aggregate or plot; the column selection below would fail.
        return "### Basic Statistics\nNo numerical columns found in the dataset."

    stats_df = pd.DataFrame({
        'mean': numeric.mean(),
        'median': numeric.median(),
        'std': numeric.std(),
        'skew': numeric.skew(),
        'missing': numeric.isnull().sum(),
    }).rename_axis('Feature')

    # Plot on an explicit figure and always close it so repeated calls do not
    # accumulate open figures, even if rendering fails.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        stats_df[['mean', 'median', 'std']].plot(kind='bar', ax=ax)
        ax.set_title('Basic Statistics')
        ax.set_ylabel('Values')
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    stats_plot = base64.b64encode(buf.getvalue()).decode()

    return f"### Basic Statistics\n{stats_df.to_markdown()} \n\n![Basic Statistics](data:image/png;base64,{stats_plot})"
134
+
135
@tool
def generate_correlation_matrix(data: Optional[pd.DataFrame] = None) -> str:
    """Generate an interactive correlation matrix using Plotly.

    Args:
        data: DataFrame to analyze. When omitted or None, the dataset stored
            in the Streamlit session under the ``data`` key is used.

    Returns:
        An HTML fragment containing the Plotly heatmap of pairwise
        correlations between numerical columns, or a short message when
        nothing can be analyzed.
    """
    if data is None:
        # The uploaded dataset is kept in the Streamlit session by the app.
        data = st.session_state.get('data')
    if data is None:
        return "Error: No dataset is available for analysis."

    numeric_data = data.select_dtypes(include=[np.number])
    if numeric_data.shape[1] == 0:
        return "No numerical columns found in the dataset."

    corr = numeric_data.corr()

    fig = px.imshow(
        corr,
        text_auto=True,
        aspect="auto",
        color_continuous_scale='RdBu',
        title='Correlation Matrix',
    )
    fig.update_layout(width=800, height=600)

    # full_html=False yields an embeddable fragment rather than a whole page.
    return fig.to_html(full_html=False)
156
+
157
@tool
def analyze_categorical_columns(data: Optional[pd.DataFrame] = None) -> str:
    """Analyze categorical columns with visualizations.

    Args:
        data: DataFrame to analyze. When omitted or None, the dataset stored
            in the Streamlit session under the ``data`` key is used.

    Returns:
        Markdown with one section per categorical column (unique count,
        missing count, top five categories and a bar chart embedded as a
        base64 PNG) followed by a summary table, or a short message when
        nothing can be analyzed.
    """
    if data is None:
        # The uploaded dataset is kept in the Streamlit session by the app.
        data = st.session_state.get('data')
    if data is None:
        return "Error: No dataset is available for analysis."

    categorical_cols = data.select_dtypes(include=['object', 'category']).columns
    if len(categorical_cols) == 0:
        return "### Categorical Columns Analysis\nNo categorical columns found in the dataset."

    analysis = {}
    sections = []

    for col in categorical_cols:
        # Count once and reuse for both the summary and the plot ordering.
        top_counts = data[col].value_counts().head(5)
        unique_vals = int(data[col].nunique())
        missing = int(data[col].isnull().sum())
        top_categories = top_counts.to_dict()

        analysis[col] = {
            'unique_values': unique_vals,
            'top_categories': top_categories,
            'missing': missing,
        }

        # Always close the figure so a failure cannot leak open figures.
        fig, ax = plt.subplots(figsize=(8, 4))
        try:
            sns.countplot(data=data, x=col, order=top_counts.index, ax=ax)
            ax.set_title(f'Top 5 Categories in {col}')
            ax.tick_params(axis='x', labelrotation=45)
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format='png')
        finally:
            plt.close(fig)
        plot_img = base64.b64encode(buf.getvalue()).decode()

        sections.append(
            f"### {col}\n"
            f"- **Unique Values:** {unique_vals}\n"
            f"- **Missing Values:** {missing}\n"
            f"- **Top Categories:** {top_categories}\n"
            f"![Top Categories in {col}](data:image/png;base64,{plot_img})\n\n"
        )

    summary_table = pd.DataFrame(analysis).T.to_markdown()
    return "".join(sections) + f"### Categorical Columns Analysis\n{summary_table}"
197
+
198
@tool
def suggest_features(data: Optional[pd.DataFrame] = None) -> str:
    """Suggest potential feature engineering steps based on data characteristics.

    Args:
        data: DataFrame to inspect. When omitted or None, the dataset stored
            in the Streamlit session under the ``data`` key is used.

    Returns:
        A newline-separated Markdown bullet list of suggestions, or a short
        message when no dataset is available.
    """
    if data is None:
        # The uploaded dataset is kept in the Streamlit session by the app.
        data = st.session_state.get('data')
    if data is None:
        return "Error: No dataset is available for analysis."

    suggestions = []
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    # Interaction terms
    if len(numeric_cols) >= 2:
        suggestions.append("• **Interaction Terms:** Consider creating interaction terms between numerical features to capture combined effects.")

    # Encoding categorical variables
    if len(categorical_cols) > 0:
        suggestions.append("• **One-Hot Encoding:** Apply one-hot encoding to categorical variables to convert them into numerical format.")
        suggestions.append("• **Label Encoding:** For ordinal categorical variables, consider label encoding to maintain order information.")

    # Handling skewness: compute each column's skew once. A NaN skew compares
    # False and is therefore skipped.
    for col in numeric_cols:
        if abs(data[col].skew()) > 1:
            suggestions.append(f"• **Log Transformation:** Apply log transformation to `{col}` to reduce skewness and stabilize variance.")

    # Missing value imputation
    for col in data.columns:
        if data[col].isnull().sum() > 0:
            suggestions.append(f"• **Imputation:** Consider imputing missing values in `{col}` using mean, median, or advanced imputation techniques.")

    # Feature scaling
    suggestions.append("• **Feature Scaling:** Apply feature scaling (Standardization or Normalization) to numerical features to ensure uniformity.")

    return "\n".join(suggestions)
231
+
232
@tool
def predictive_analysis(data: Optional[pd.DataFrame], target: str) -> str:
    """Perform predictive analytics by training a classification model.

    Args:
        data: DataFrame holding the features and the target column. When
            None, the dataset stored in the Streamlit session under the
            ``data`` key is used.
        target: Name of the column to predict.

    Returns:
        A Markdown/HTML report with the classification report, a confusion
        matrix, a ROC curve and AUC score (two-class targets only) and the ID
        under which the trained model was stored, or an ``Error: ...``
        message when the analysis cannot be run.
    """
    if data is None:
        # The uploaded dataset is kept in the Streamlit session by the app.
        data = st.session_state.get('data')
    if data is None:
        return "Error: No dataset is available for analysis."

    if target not in data.columns:
        return f"Error: Target column `{target}` not found in the dataset."

    # Drop rows with a missing target BEFORE encoding: category codes map NaN
    # to -1, which would otherwise turn missing targets into a bogus class.
    # dropna returns a new frame, so the caller's dataset is never modified.
    data = data.dropna(subset=[target])

    y = data[target]
    class_names = None
    if y.dtype == 'object' or y.dtype.name == 'category':
        encoded = y.astype('category').cat.remove_unused_categories()
        class_names = [str(name) for name in encoded.cat.categories]
        y = encoded.cat.codes

    if y.nunique() < 2:
        return f"Error: Target column `{target}` needs at least two distinct values to train a classifier."

    X = data.drop(columns=[target])
    if X.shape[1] == 0:
        return "Error: The dataset has no feature columns besides the target."

    # Simple imputation: numeric_only keeps median() from failing on
    # non-numeric columns. Categorical gaps become all-zero dummy rows.
    X = X.fillna(X.median(numeric_only=True))
    X = pd.get_dummies(X, drop_first=True)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Train a Random Forest Classifier (as an example)
    from sklearn.ensemble import RandomForestClassifier
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)

    report = classification_report(y_test, y_pred, output_dict=True)
    report_df = pd.DataFrame(report).transpose()

    # Label the confusion matrix with the classes the model actually learned
    # so targets with any number of classes are rendered correctly.
    classes = list(clf.classes_)
    if class_names is not None:
        tick_labels = [class_names[int(code)] for code in classes]
    else:
        tick_labels = [str(label) for label in classes]
    cm = confusion_matrix(y_test, y_pred, labels=classes)
    fig_cm = px.imshow(cm, text_auto=True, labels=dict(x="Predicted", y="Actual", color="Count"),
                       x=tick_labels, y=tick_labels,
                       title="Confusion Matrix")
    cm_html = fig_cm.to_html(full_html=False)

    # A single ROC curve is only defined for two-class problems.
    roc_html = None
    roc_auc = None
    if len(classes) == 2:
        y_proba = clf.predict_proba(X_test)[:, 1]
        fpr, tpr, _ = roc_curve(y_test, y_proba, pos_label=classes[1])
        roc_auc = auc(fpr, tpr)
        fig_roc = go.Figure()
        fig_roc.add_trace(go.Scatter(x=fpr, y=tpr, mode='lines', name=f'ROC Curve (AUC = {roc_auc:.2f})'))
        fig_roc.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random Guess', line=dict(dash='dash')))
        fig_roc.update_layout(title='Receiver Operating Characteristic (ROC) Curve',
                              xaxis_title='False Positive Rate',
                              yaxis_title='True Positive Rate')
        roc_html = fig_roc.to_html(full_html=False)

    # Save the model for potential future use. The temporary file is kept on
    # purpose (delete=False); in a real-world scenario it would go to
    # persistent storage instead.
    model_id = str(uuid.uuid4())
    with tempfile.NamedTemporaryFile(delete=False, suffix='.joblib') as tmp_model_file:
        joblib.dump(clf, tmp_model_file.name)

    # Register the model on the session's agent when one is available.
    agent = st.session_state.get('agent')
    if agent is not None and hasattr(agent, 'models'):
        agent.models[model_id] = clf

    sections = [
        f"### Predictive Analytics Report for Target: `{target}`",
        "**Model Used:** Random Forest Classifier",
        "**Classification Report:**",
        report_df.to_markdown(),
        "**Confusion Matrix:**",
        cm_html,
    ]
    if roc_html is not None:
        sections += ["**ROC Curve:**", roc_html, f"**AUC Score:** {roc_auc:.2f}"]
    else:
        sections.append("**ROC Curve:** Not available for targets with more than two classes.")
    sections += [
        f"**Model ID:** `{model_id}`",
        "*You can use this Model ID to retrieve or update the model in future analyses.*",
    ]

    return "\n\n".join(sections)
324
+
325
+ # ------------------------------
326
+ # Report Exporting Function
327
+ # ------------------------------
328
def export_report(content: str, filename: str):
    """Render ``content`` as a PDF and offer it through a download button.

    Args:
        content: Report body (HTML/Markdown text) passed to wkhtmltopdf.
        filename: Base name, without extension, suggested for the download.

    The PDF is produced in memory, so nothing is written to the working
    directory. Any failure (for example a missing wkhtmltopdf binary) is
    shown with ``st.error`` instead of being raised.
    """
    pdf_name = f"{filename}.pdf"

    try:
        # pdfkit.configuration() locates the wkhtmltopdf binary and raises
        # when it cannot be found.
        config = pdfkit.configuration()
        # output_path=False makes pdfkit return the PDF bytes. The explicit
        # encoding keeps non-ASCII characters intact in the rendered report.
        pdf_bytes = pdfkit.from_string(
            content,
            False,
            configuration=config,
            options={'encoding': 'UTF-8'},
        )

        # Provide download link
        st.download_button(label="📥 Download Report as PDF",
                           data=pdf_bytes,
                           file_name=pdf_name,
                           mime='application/pdf')
    except Exception as e:
        st.error(f"⚠️ Error exporting report: {str(e)}")
357
+
358
+ # ------------------------------
359
+ # Main Application Function
360
+ # ------------------------------
361
# Analyses that call exactly one tool: label -> (spinner text, agent prompt).
_TOOL_ANALYSES = {
    "Basic Statistics": (
        '📈 Analyzing basic statistics...',
        "Use the analyze_basic_stats tool to analyze this dataset and "
        "provide insights about the numerical distributions.",
    ),
    "Correlation Analysis": (
        '📊 Generating correlation matrix...',
        "Use the generate_correlation_matrix tool to analyze correlations "
        "and explain any strong relationships found.",
    ),
    "Categorical Analysis": (
        '📊 Analyzing categorical columns...',
        "Use the analyze_categorical_columns tool to examine the "
        "categorical variables and explain the distributions.",
    ),
    "Feature Engineering": (
        '🔧 Generating feature suggestions...',
        "Use the suggest_features tool to recommend potential "
        "feature engineering steps for this dataset.",
    ),
}


def _run_cached_analysis(cache_key, prompt):
    """Run the agent once per dataset for ``cache_key``; return (text, is_new)."""
    cache = st.session_state['analysis_cache']
    if cache_key in cache:
        return cache[cache_key], False
    # The agent may return a non-string object; normalise it for rendering
    # and for string concatenation into the report.
    result = str(st.session_state['agent'].run(prompt))
    cache[cache_key] = result
    return result, True


def main():
    """Streamlit entry point: upload a CSV, run analyses, export a report.

    Streamlit re-executes this function on every widget interaction, so the
    dataset, the agent and the results of the tool-based analyses are kept in
    ``st.session_state`` and only rebuilt when a different file is uploaded.
    """
    st.set_page_config(page_title="📊 Business Intelligence Assistant", layout="wide")
    st.title("📊 **Business Intelligence Assistant**")
    st.write("Upload your dataset and receive comprehensive analyses, interactive visualizations, and predictive insights.")

    # Initialize session state
    session_defaults = (
        ('data', None),
        ('agent', None),
        ('report_content', ""),
        ('analysis_cache', {}),
        ('dataset_signature', None),
    )
    for key, default in session_defaults:
        if key not in st.session_state:
            st.session_state[key] = default

    # File Uploader
    uploaded_file = st.file_uploader("📥 **Upload a CSV file**", type="csv")

    try:
        if uploaded_file is not None:
            # Name plus size identifies the upload well enough to detect that
            # the user switched files between reruns.
            signature = (uploaded_file.name, uploaded_file.size)
            is_new_upload = st.session_state['dataset_signature'] != signature
            if is_new_upload or st.session_state['agent'] is None:
                with st.spinner('🔄 Loading and processing your data...'):
                    data = pd.read_csv(uploaded_file)
                    agent = DataAnalysisAgent(
                        dataset=data,
                        tools=[analyze_basic_stats, generate_correlation_matrix,
                               analyze_categorical_columns, suggest_features, predictive_analysis],
                        model=GroqLLM(),
                        additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn", "plotly"]
                    )
                # Commit to the session only after everything succeeded.
                st.session_state['data'] = data
                st.session_state['agent'] = agent
                st.session_state['dataset_signature'] = signature
                # Results and report text belong to the previous dataset.
                st.session_state['analysis_cache'] = {}
                st.session_state['report_content'] = ""

            data = st.session_state['data']
            st.success(f"✅ Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns")
            st.subheader("🔍 **Data Preview**")
            st.dataframe(data.head())

        if st.session_state['data'] is not None and st.session_state['agent'] is not None:
            # Sidebar for Analysis Selection
            st.sidebar.header("🛠️ **Select Analysis Type**")
            analysis_type = st.sidebar.selectbox(
                "Choose analysis type",
                ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
                 "Feature Engineering", "Predictive Analytics", "Custom Question"]
            )

            if analysis_type in _TOOL_ANALYSES:
                spinner_text, prompt = _TOOL_ANALYSES[analysis_type]
                with st.spinner(spinner_text):
                    result, is_new = _run_cached_analysis(analysis_type, prompt)

                if analysis_type == "Correlation Analysis":
                    # The correlation tool returns an HTML fragment with
                    # scripts, which needs the HTML component to render.
                    st.components.v1.html(result, height=600)
                    report_entry = "### Correlation Analysis\n" + result + "\n\n"
                else:
                    st.markdown(result, unsafe_allow_html=True)
                    report_entry = result + "\n\n"

                # Append only fresh results so reruns do not duplicate them.
                if is_new:
                    st.session_state['report_content'] += report_entry

            elif analysis_type == "Predictive Analytics":
                with st.form("Predictive Analytics Form"):
                    st.write("🔮 **Predictive Analytics**")
                    target = st.selectbox("Select the target variable for prediction:", options=st.session_state['data'].columns)
                    submit = st.form_submit_button("🚀 Run Predictive Analysis")

                if submit:
                    with st.spinner('🚀 Performing predictive analysis...'):
                        result = str(st.session_state['agent'].run(
                            f"Use the predictive_analysis tool to build a classification model with `{target}` as the target variable."
                        ))
                    st.markdown(result, unsafe_allow_html=True)
                    st.session_state['report_content'] += result + "\n\n"
                    export_report(result, "Predictive_Analysis_Report")

            elif analysis_type == "Custom Question":
                with st.expander("📝 **Ask a Custom Question**"):
                    question = st.text_input("What would you like to know about your data?")
                    if st.button("🔍 Get Answer"):
                        if question:
                            with st.spinner('🧠 Processing your question...'):
                                result = str(st.session_state['agent'].run(question))
                            st.markdown(result, unsafe_allow_html=True)
                            st.session_state['report_content'] += f"### Custom Question: {question}\n{result}\n\n"
                        else:
                            st.warning("Please enter a question.")

            # Option to Export Report
            if st.session_state['report_content']:
                st.sidebar.markdown("---")
                if st.sidebar.button("📤 **Export Analysis Report**"):
                    export_report(st.session_state['report_content'], "Business_Intelligence_Report")
                    st.sidebar.success("✅ Report exported successfully!")

    except Exception as e:
        # Top-level boundary: show the failure in the UI instead of a traceback.
        st.error(f"⚠️ An error occurred: {str(e)}")


# ------------------------------
# Application Entry Point
# ------------------------------
if __name__ == "__main__":
    main()