Mustehson committed
Commit 0e13b2c · 1 Parent(s): 7c2e7ac

Files changed (3):
  1. app.py +73 -83
  2. requirements.txt +2 -1
  3. utils.py +162 -0
app.py CHANGED
@@ -5,8 +5,13 @@ import gradio as gr
 import pandas as pd
 import pandera as pa
 from pandera import Column
-import ydata_profiling as pp
+import random
+from dataprep.eda import compute
 from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
+from .utils import (
+    format_num_stats, format_cat_stats,
+    format_ov_stats, format_insights
+)
 from langsmith import traceable
 from langchain import hub
 import warnings
@@ -38,7 +43,7 @@ for model in models:
         print(f"Error for model {model}: {e}")
         continue
 
-llm = ChatHuggingFace(llm=endpoint).bind_tools(tools=[], max_tokens=8192)
+llm = ChatHuggingFace(llm=endpoint).bind(max_tokens=4096)
 #---------------------------------------
 
 #-----LOAD PROMPT FROM LANCHAIN HUB-----
@@ -65,37 +70,80 @@ def update_table_names(schema_name):
     tables = get_tables_names(schema_name)
     return gr.update(choices=tables)
 
+# Get Schema
+def get_table_schema(table):
+    result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df()
+    ddl_create = result.iloc[0,0]
+    parent_database = result.iloc[0,1]
+    schema_name = result.iloc[0,2]
+    full_path = f"{parent_database}.{schema_name}.{table}"
+    if schema_name != "main":
+        old_path = f"{schema_name}.{table}"
+    else:
+        old_path = table
+    ddl_create = ddl_create.replace(old_path, full_path)
+    return full_path
+
 def get_data_df(schema):
     print('Getting Dataframe from the Database')
     return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()
 
+def calcualte_stats(df):
+    indev_stats = []
+    cols = []
+
+    _df = df.copy()
+
+    num_cols = _df.select_dtypes(include=['number'], exclude=['datetime']).columns
+    cat_cols = _df.select_dtypes(include=['object'], exclude=['datetime']).columns
+
+
+    _all_stats = compute(_df)
+    all_stats = format_ov_stats(_all_stats['stats'])
+    insights = format_insights(_all_stats['overview_insights'])
+
+    for i, col in enumerate(random.sample(num_cols.tolist()+cat_cols.tolist(), 2)):
+        _indv_data = compute(_df, col)
+
+        if col in cat_cols:
+            indev_data_cat = format_cat_stats(_indv_data["data"])
+
+            indev_stats.append(pd.DataFrame([indev_data_cat['Overview']], index=[f'{col}_stats']).T)
+
+        elif col in num_cols:
+            try:
+                indev_data_num = format_num_stats(_indv_data["data"])
+            except:
+                indev_data_num = format_cat_stats(_indv_data["data"])
+
+            indev_stats.append(pd.DataFrame([indev_data_num['Overview']], index=[f'{col}_stats']).T)
+
+    return {
+        "overall_stats": pd.DataFrame(all_stats[0], index=['Dataset Statistics']).T,
+        "insights": insights,
+        "stats_1": indev_stats[0],
+        "stats_2": indev_stats[1]
+    }
+
 def df_summary(df):
     summary = []
 
     for column in df.columns:
         if pd.api.types.is_numeric_dtype(df[column]):
             summary.append({
-                "column": column,
-                "max": df[column].max(),
-                "min": df[column].min(),
-                "count": df[column].count(),
-                "nunique": df[column].nunique(),
-                "dtype": str(df[column].dtype),
-                "top": None
+                "column": column, "max": df[column].max(), "min": df[column].min(),
+                "count": df[column].count(), "nunique": df[column].nunique(),
+                "dtype": str(df[column].dtype), "top": None
             })
 
         elif pd.api.types.is_categorical_dtype(df[column]) or pd.api.types.is_object_dtype(df[column]):
             top_value = df[column].mode().iloc[0] if not df[column].mode().empty else None
 
             summary.append({
-                "column": column,
-                "max": None,
-                "min": None,
-                "count": df[column].count(),
-                "nunique": df[column].nunique(),
-                "dtype": str(df[column].dtype),
-                "top": top_value
+                "column": column, "max": None, "min": None, "count": df[column].count(),
+                "nunique": df[column].nunique(), "dtype": str(df[column].dtype), "top": top_value
             })
+
     summary_df = pd.DataFrame(summary)
     return summary_df.reset_index(drop=True)
 
@@ -119,33 +167,6 @@ def run_llm(messages):
     return tests
 
 
-# Get Schema
-def get_table_schema(table):
-    result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df()
-    ddl_create = result.iloc[0,0]
-    parent_database = result.iloc[0,1]
-    schema_name = result.iloc[0,2]
-    full_path = f"{parent_database}.{schema_name}.{table}"
-    if schema_name != "main":
-        old_path = f"{schema_name}.{table}"
-    else:
-        old_path = table
-    ddl_create = ddl_create.replace(old_path, full_path)
-    return full_path
-
-def describe(df):
-
-    numerical_info = pd.DataFrame()
-    categorical_info = pd.DataFrame()
-    if len(df.select_dtypes(include=['number']).columns) >= 1:
-        numerical_info = df.select_dtypes(include=['number']).describe().T.reset_index()
-        numerical_info.rename(columns={'index': 'column'}, inplace=True)
-    if len(df.select_dtypes(include=['object']).columns) >= 1:
-        categorical_info = df.select_dtypes(include=['object']).describe().T.reset_index()
-        categorical_info.rename(columns={'index': 'column'}, inplace=True)
-
-    return numerical_info, categorical_info
-
 def validate_pandera(tests, df):
     validation_results = []
 
@@ -165,41 +186,6 @@ def validate_pandera(tests, df):
         })
     return pd.DataFrame(validation_results)
 
-def statistics(df):
-    profile = pp.ProfileReport(df)
-    report_dict = profile.get_description()
-    description, alerts = report_dict.table, report_dict.alerts
-    # Statistics
-    mapping = {
-        'n': 'Number of observations',
-        'n_var': 'Number of variables',
-        'n_cells_missing': 'Number of cells missing',
-        'n_vars_with_missing': 'Number of columns with missing data',
-        'n_vars_all_missing': 'Columns with all missing data',
-        'p_cells_missing': 'Missing cells (%)',
-        'n_duplicates': 'Duplicated rows',
-        'p_duplicates': 'Duplicated rows (%)',
-    }
-
-    updated_data = {mapping.get(k, k): v for k, v in description.items() if k != 'types'}
-    # Add flattened types information
-    if 'Text' in description.get('types', {}):
-        updated_data['Number of text columns'] = description['types']['Text']
-    if 'Categorical' in description.get('types', {}):
-        updated_data['Number of categorical columns'] = description['types']['Categorical']
-    if 'Numeric' in description.get('types', {}):
-        updated_data['Number of numeric columns'] = description['types']['Numeric']
-    if 'DateTime' in description.get('types', {}):
-        updated_data['Number of datetime columns'] = description['types']['DateTime']
-
-    df_statistics = pd.DataFrame(list(updated_data.items()), columns=['Statistic Description', 'Value'])
-    df_statistics['Value'] = df_statistics['Value'].astype(int)
-
-    # Alerts
-    alerts_list = [(str(alert).replace('[', '').replace(']', ''), alert.alert_type_name) for alert in alerts]
-    df_alerts = pd.DataFrame(alerts_list, columns=['Data Quality Issue', 'Category'])
-
-    return df_statistics, df_alerts
 #---------------------------------------
 
 
@@ -208,22 +194,26 @@ def statistics(df):
 def main(table):
     schema = get_table_schema(table)
     df = get_data_df(schema)
-    df_statistics, df_alerts = statistics(df)
-    describe_num, describe_cat = describe(df)
-
+
     messages = format_prompt(df=df)
     tests = run_llm(messages)
     print(tests)
 
+    stats = calcualte_stats(df)
+    df_insights = stats['insights']
+    df_statistics = stats['overall_stats']
+    df_stat_1 = stats['stats_1']
+    df_stat_2 = stats['stats_2']
+
     if isinstance(tests, Exception):
         tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
-        return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests, pd.DataFrame([])
+        return df.head(10), df_statistics, df_insights, df_stat_1, df_stat_2, tests, pd.DataFrame([])
 
     tests_df = pd.DataFrame(tests)
     tests_df.rename(columns={tests_df.columns[0]: 'Column', tests_df.columns[1]: 'Rule Name', tests_df.columns[2]: 'Rules' }, inplace=True)
     pandera_results = validate_pandera(tests, df)
 
-    return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results
+    return df.head(10), df_statistics, df_insights, df_stat_1, df_stat_2, tests_df, pandera_results
 
 def user_results(table, text_query):
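
For reference, a minimal sketch (not part of the commit) of the overview half of the new statistics path. It assumes dataprep==0.4.4 is installed and that this commit's utils.py is importable as a plain module; the 'stats' and 'overview_insights' keys simply mirror how app.py indexes the compute() result above.

import pandas as pd
from dataprep.eda import compute
from utils import format_ov_stats, format_insights  # helpers added in this commit

# Toy frame with one numeric and one categorical column (illustrative only).
df = pd.DataFrame({"age": [23, 35, 41, 29], "city": ["NY", "LA", "NY", "SF"]})

overview = compute(df)                                     # dataset-level EDA result
stats, dtype_counts = format_ov_stats(overview["stats"])   # formatted overview stats + dtype counts
insights = format_insights(overview["overview_insights"])  # data-quality insights as a DataFrame

print(pd.DataFrame(stats, index=["Dataset Statistics"]).T)
print(insights)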
 
requirements.txt CHANGED
@@ -8,4 +8,5 @@ langsmith==0.1.135
 pandera==0.20.4
 ydata-profiling==v4.11.0
 langchain-core==0.3.12
-langchain==0.3.4
+langchain==0.3.4
+dataprep==0.4.4
utils.py ADDED
@@ -0,0 +1,162 @@
+
+import numpy as np
+import pandas as pd
+# -----------------Numerical Statistics-----------------
+def format_values(key, value):
+
+    if not isinstance(value, (int, float)):
+        # if value is a time
+        return str(value)
+
+    if "Memory" in key:
+        # for memory usage
+        ind = 0
+        unit = dict(enumerate(["B", "KB", "MB", "GB", "TB"], 0))
+        while value > 1024:
+            value /= 1024
+            ind += 1
+        return f"{value:.1f} {unit[ind]}"
+
+    if (value * 10) % 10 == 0:
+        # if value is int but in a float form with 0 at last digit
+        value = int(value)
+    if abs(value) >= 1000000:
+        return f"{value:.5g}"
+    elif abs(value) >= 1000000 or abs(value) < 0.001:
+        value = f"{value:.5g}"
+    elif abs(value) >= 1:
+        # eliminate trailing zeros
+        pre_value = float(f"{value:.4f}")
+        value = int(pre_value) if (pre_value * 10) % 10 == 0 else pre_value
+    elif 0.001 <= abs(value) < 1:
+        value = f"{value:.4g}"
+    else:
+        value = str(value)
+
+    if "%" in key:
+        # for percentage, only use digits before notation sign for extreme small number
+        value = f"{float(value):.1%}"
+    return str(value)
+
+def format_num_stats(data):
+    """
+    Format numerical statistics
+    """
+    overview = {
+        "Approximate Distinct Count": data["nuniq"],
+        "Approximate Unique (%)": data["nuniq"] / data["npres"],
+        "Missing": data["nrows"] - data["npres"],
+        "Missing (%)": 1 - (data["npres"] / data["nrows"]),
+        "Infinite": (data["npres"] - data["nreals"]),
+        "Infinite (%)": (data["npres"] - data["nreals"]) / data["nrows"],
+        "Memory Size": data["mem_use"],
+        "Mean": data["mean"],
+        "Minimum": data["min"],
+        "Maximum": data["max"],
+        "Zeros": data["nzero"],
+        "Zeros (%)": data["nzero"] / data["nrows"],
+        "Negatives": data["nneg"],
+        "Negatives (%)": data["nneg"] / data["nrows"],
+    }
+    data["qntls"].index = np.round(data["qntls"].index, 2)
+    quantile = {
+        "Minimum": data["min"],
+        "5-th Percentile": data["qntls"].loc[0.05],
+        "Q1": data["qntls"].loc[0.25],
+        "Median": data["qntls"].loc[0.50],
+        "Q3": data["qntls"].loc[0.75],
+        "95-th Percentile": data["qntls"].loc[0.95],
+        "Maximum": data["max"],
+        "Range": data["max"] - data["min"],
+        "IQR": data["qntls"].loc[0.75] - data["qntls"].loc[0.25],
+    }
+    descriptive = {
+        "Mean": data["mean"],
+        "Standard Deviation": data["std"],
+        "Variance": data["std"] ** 2,
+        "Sum": data["mean"] * data["npres"],
+        "Skewness": float(data["skew"]),
+        "Kurtosis": float(data["kurt"]),
+        "Coefficient of Variation": data["std"] / data["mean"] if data["mean"] != 0 else np.nan,
+    }
+
+    # return {
+    #     "Overview": {k: _format_values(k, v) for k, v in overview.items()},
+    #     # "Quantile Statistics": {k: _format_values(k, v) for k, v in quantile.items()},
+    #     # "Descriptive Statistics": {k: _format_values(k, v) for k, v in descriptive.items()},
+    # }
+
+    return {
+        "Overview": {**{k: format_values(k, v) for k, v in overview.items()},
+                     **{k: format_values(k, v) for k, v in quantile.items()},
+                     **{k: format_values(k, v) for k, v in descriptive.items()}}
+    }
+# -----------------------------------------------------
+
+
+# -----------------Categorical Statistics-----------------
+
+def format_cat_stats(
+    data
+):
+    """
+    Format categorical statistics
+    """
+    stats = data['stats']
+    len_stats = data['len_stats']
+    letter_stats = data["letter_stats"]
+    ov_stats = {
+        "Approximate Distinct Count": stats["nuniq"],
+        "Approximate Unique (%)": stats["nuniq"] / stats["npres"],
+        "Missing": stats["nrows"] - stats["npres"],
+        "Missing (%)": 1 - stats["npres"] / stats["nrows"],
+        "Memory Size": stats["mem_use"],
+    }
+    sampled_rows = ("1st row", "2nd row", "3rd row", "4th row", "5th row")
+    smpl = dict(zip(sampled_rows, stats["first_rows"]))
+
+    # return {
+    #     "Overview": {k: _format_values(k, v) for k, v in ov_stats.items()},
+    #     "Length": {k: _format_values(k, v) for k, v in len_stats.items()},
+    #     "Sample": {k: f"{v[:18]}..." if len(v) > 18 else v for k, v in smpl.items()},
+    #     "Letter": {k: _format_values(k, v) for k, v in letter_stats.items()},
+    # }
+    return {
+        "Overview": {**{k: format_values(k, v) for k, v in ov_stats.items()},
+                     **{k: format_values(k, v) for k, v in len_stats.items()},
+                     }
+    }
+# -----------------------------------------------------
+
+
+def format_ov_stats(stats) :
+
+    nrows, ncols, npresent_cells, nrows_wo_dups, mem_use, dtypes_cnt = stats.values()
+    ncells = nrows * ncols
+
+    data = {
+        "Number of Variables": ncols,
+        "Number of Rows": nrows,
+        "Missing Cells": float(ncells - npresent_cells),
+        "Missing Cells (%)": 1 - (npresent_cells / ncells),
+        "Duplicate Rows": nrows - nrows_wo_dups,
+        "Duplicate Rows (%)": 1 - (nrows_wo_dups / nrows),
+        "Total Size in Memory": float(mem_use),
+        "Average Row Size in Memory": mem_use / nrows,
+    }
+    return {k: format_values(k, v) for k, v in data.items()}, dtypes_cnt
+
+
+def format_insights(data):
+    data_list = []
+    for key, value_list in data.items():
+        for item in value_list:
+            for category, description in item.items():
+                data_list.append({'Category': category, 'Description': description})
+
+    insights_df = pd.DataFrame(data_list)
+
+    insights_df['Description'] = insights_df['Description'].str.replace(r'/\*start\*/', '', regex=True)
+    insights_df['Description'] = insights_df['Description'].str.replace(r'/\*end\*/', '', regex=True)
+
+    return insights_df
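
As a quick sanity check on format_values above (an editor sketch, not part of the commit; it assumes utils.py is importable), these calls exercise the memory-unit, percentage, and integer-rounding branches, with expected outputs traced from the code:

from utils import format_values

print(format_values("Memory Size", 2048))       # -> "2.0 KB"  (memory-unit branch)
print(format_values("Missing (%)", 0.25))       # -> "25.0%"   (percentage branch)
print(format_values("Number of Rows", 1000.0))  # -> "1000"    (trailing .0 stripped)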