mgbam commited on
Commit
2311473
·
verified ·
1 Parent(s): 439fcbb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -10
app.py CHANGED
@@ -26,12 +26,27 @@ import uuid # For generating unique report IDs
26
  # ------------------------------
27
  class GroqLLM:
28
  """Enhanced LLM interface with support for generating natural language summaries."""
29
- def __init__(self, model_name="llama-3.1-8B-Instant"):
 
 
 
 
 
 
 
30
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
31
  self.model_name = model_name
32
 
33
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
34
- """Make the class callable as required by smolagents"""
 
 
 
 
 
 
 
 
35
  try:
36
  # Handle different prompt formats
37
  if isinstance(prompt, (dict, list)):
@@ -63,18 +78,39 @@ class GroqLLM:
63
  # ------------------------------
64
  class DataAnalysisAgent(CodeAgent):
65
  """Extended CodeAgent with dataset awareness and predictive analytics capabilities."""
 
66
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
 
 
 
 
 
 
 
 
67
  super().__init__(*args, **kwargs)
68
  self._dataset = dataset
69
  self.models = {} # To store trained models
70
 
71
  @property
72
  def dataset(self) -> pd.DataFrame:
73
- """Access the stored dataset"""
 
 
 
 
74
  return self._dataset
75
 
76
  def run(self, prompt: str) -> str:
77
- """Override run method to include dataset context and support predictive tasks"""
 
 
 
 
 
 
 
 
78
  dataset_info = f"""
79
  Dataset Shape: {self.dataset.shape}
80
  Columns: {', '.join(self.dataset.columns)}
@@ -96,7 +132,22 @@ class DataAnalysisAgent(CodeAgent):
96
 
97
  @tool
98
  def analyze_basic_stats(data: pd.DataFrame) -> str:
99
- """Calculate and visualize basic statistical measures for numerical columns."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  if data is None:
101
  data = tool.agent.dataset
102
 
@@ -134,7 +185,21 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
134
 
135
  @tool
136
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
137
- """Generate an interactive correlation matrix using Plotly."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  if data is None:
139
  data = tool.agent.dataset
140
 
@@ -156,7 +221,21 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
156
 
157
  @tool
158
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
159
- """Analyze categorical columns with visualizations."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  if data is None:
161
  data = tool.agent.dataset
162
 
@@ -197,7 +276,20 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
197
 
198
  @tool
199
  def suggest_features(data: pd.DataFrame) -> str:
200
- """Suggest potential feature engineering steps based on data characteristics."""
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  if data is None:
202
  data = tool.agent.dataset
203
 
@@ -231,7 +323,21 @@ def suggest_features(data: pd.DataFrame) -> str:
231
 
232
  @tool
233
  def predictive_analysis(data: pd.DataFrame, target: str) -> str:
234
- """Perform predictive analytics by training a classification model."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  if data is None:
236
  data = tool.agent.dataset
237
 
@@ -326,7 +432,19 @@ def predictive_analysis(data: pd.DataFrame, target: str) -> str:
326
  # Report Exporting Function
327
  # ------------------------------
328
  def export_report(content: str, filename: str):
329
- """Export the given content as a PDF report."""
 
 
 
 
 
 
 
 
 
 
 
 
330
  # Save content to a temporary HTML file
331
  with tempfile.NamedTemporaryFile(delete=False, suffix='.html') as tmp_file:
332
  tmp_file.write(content.encode('utf-8'))
 
26
  # ------------------------------
27
  class GroqLLM:
28
  """Enhanced LLM interface with support for generating natural language summaries."""
29
+
30
+ def __init__(self, model_name: str = "llama-3.1-8B-Instant"):
31
+ """
32
+ Initialize the GroqLLM with a specified model.
33
+
34
+ Args:
35
+ model_name (str): The name of the language model to use.
36
+ """
37
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
38
  self.model_name = model_name
39
 
40
  def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
41
+ """
42
+ Make the class callable as required by smolagents.
43
+
44
+ Args:
45
+ prompt (Union[str, dict, List[Dict]]): The input prompt for the language model.
46
+
47
+ Returns:
48
+ str: The generated response from the language model.
49
+ """
50
  try:
51
  # Handle different prompt formats
52
  if isinstance(prompt, (dict, list)):
 
78
  # ------------------------------
79
  class DataAnalysisAgent(CodeAgent):
80
  """Extended CodeAgent with dataset awareness and predictive analytics capabilities."""
81
+
82
  def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
83
+ """
84
+ Initialize the DataAnalysisAgent with the provided dataset.
85
+
86
+ Args:
87
+ dataset (pd.DataFrame): The dataset to analyze.
88
+ *args: Variable length argument list.
89
+ **kwargs: Arbitrary keyword arguments.
90
+ """
91
  super().__init__(*args, **kwargs)
92
  self._dataset = dataset
93
  self.models = {} # To store trained models
94
 
95
  @property
96
  def dataset(self) -> pd.DataFrame:
97
+ """Access the stored dataset.
98
+
99
+ Returns:
100
+ pd.DataFrame: The dataset stored in the agent.
101
+ """
102
  return self._dataset
103
 
104
  def run(self, prompt: str) -> str:
105
+ """
106
+ Override the run method to include dataset context and support predictive tasks.
107
+
108
+ Args:
109
+ prompt (str): The task prompt for analysis.
110
+
111
+ Returns:
112
+ str: The result of the analysis.
113
+ """
114
  dataset_info = f"""
115
  Dataset Shape: {self.dataset.shape}
116
  Columns: {', '.join(self.dataset.columns)}
 
132
 
133
  @tool
134
  def analyze_basic_stats(data: pd.DataFrame) -> str:
135
+ """
136
+ Calculate and visualize basic statistical measures for numerical columns.
137
+
138
+ This function computes fundamental statistical metrics including mean, median,
139
+ standard deviation, skewness, and counts of missing values for all numerical
140
+ columns in the provided DataFrame. It also generates a bar chart visualizing
141
+ the mean, median, and standard deviation for each numerical feature.
142
+
143
+ Args:
144
+ data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
145
+ The DataFrame should contain at least one numerical column
146
+ for meaningful analysis.
147
+
148
+ Returns:
149
+ str: A markdown-formatted string containing the statistics and the generated plot.
150
+ """
151
  if data is None:
152
  data = tool.agent.dataset
153
 
 
185
 
186
  @tool
187
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
188
+ """
189
+ Generate an interactive correlation matrix using Plotly.
190
+
191
+ This function creates an interactive heatmap visualization showing the correlations between
192
+ all numerical columns in the dataset. Users can hover over cells to see correlation values
193
+ and interact with the plot (zoom, pan).
194
+
195
+ Args:
196
+ data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
197
+ The DataFrame should contain at least two numerical columns
198
+ for correlation analysis.
199
+
200
+ Returns:
201
+ str: An HTML string representing the interactive correlation matrix plot.
202
+ """
203
  if data is None:
204
  data = tool.agent.dataset
205
 
 
221
 
222
  @tool
223
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
224
+ """
225
+ Analyze categorical columns with visualizations.
226
+
227
+ This function examines categorical columns to identify unique values, top categories,
228
+ and missing value counts. It also generates bar charts for the top 5 categories in each
229
+ categorical feature.
230
+
231
+ Args:
232
+ data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
233
+ The DataFrame should contain at least one categorical column
234
+ for meaningful analysis.
235
+
236
+ Returns:
237
+ str: A markdown-formatted string containing analysis results and embedded plots.
238
+ """
239
  if data is None:
240
  data = tool.agent.dataset
241
 
 
276
 
277
  @tool
278
  def suggest_features(data: pd.DataFrame) -> str:
279
+ """
280
+ Suggest potential feature engineering steps based on data characteristics.
281
+
282
+ This function analyzes the dataset's structure and statistical properties to
283
+ recommend possible feature engineering steps that could improve model performance.
284
+
285
+ Args:
286
+ data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
287
+ The DataFrame can contain both numerical and categorical columns.
288
+
289
+ Returns:
290
+ str: A string containing suggestions for feature engineering based on
291
+ the characteristics of the input data.
292
+ """
293
  if data is None:
294
  data = tool.agent.dataset
295
 
 
323
 
324
  @tool
325
  def predictive_analysis(data: pd.DataFrame, target: str) -> str:
326
+ """
327
+ Perform predictive analytics by training a classification model.
328
+
329
+ This function builds a classification model using Random Forest, evaluates its performance,
330
+ and provides detailed metrics and visualizations such as the confusion matrix and ROC curve.
331
+
332
+ Args:
333
+ data (pd.DataFrame): A pandas DataFrame containing the dataset to analyze.
334
+ The DataFrame should contain the target variable for prediction.
335
+ target (str): The name of the target variable column in the dataset.
336
+
337
+ Returns:
338
+ str: A markdown-formatted string containing the classification report, confusion matrix,
339
+ ROC curve, AUC score, and a unique Model ID.
340
+ """
341
  if data is None:
342
  data = tool.agent.dataset
343
 
 
432
  # Report Exporting Function
433
  # ------------------------------
434
  def export_report(content: str, filename: str):
435
+ """
436
+ Export the given content as a PDF report.
437
+
438
+ This function converts markdown content into a PDF file using pdfkit and provides
439
+ a download button for users to obtain the report.
440
+
441
+ Args:
442
+ content (str): The markdown content to be included in the PDF report.
443
+ filename (str): The desired name for the exported PDF file.
444
+
445
+ Returns:
446
+ None
447
+ """
448
  # Save content to a temporary HTML file
449
  with tempfile.NamedTemporaryFile(delete=False, suffix='.html') as tmp_file:
450
  tmp_file.write(content.encode('utf-8'))