Spaces:
Sleeping
Sleeping
Mustehson
commited on
Commit
·
6dda383
1
Parent(s):
d14334a
Refactoring & Logs
Browse files- app.py +20 -48
- prompt.py +0 -123
- requirements.txt +2 -1
app.py
CHANGED
@@ -6,10 +6,9 @@ import pandas as pd
|
|
6 |
import pandera as pa
|
7 |
from pandera import Column
|
8 |
import ydata_profiling as pp
|
9 |
-
from langchain_core.messages import HumanMessage, SystemMessage
|
10 |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
11 |
-
from prompt import PROMPT_PANDERA, PANDERA_USER_INPUT_PROMPT
|
12 |
from langsmith import traceable
|
|
|
13 |
import warnings
|
14 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
15 |
|
@@ -18,29 +17,6 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
18 |
TAB_LINES = 8
|
19 |
# Load Token
|
20 |
md_token = os.getenv('MD_TOKEN')
|
21 |
-
|
22 |
-
INPUT_PROMPT = '''
|
23 |
-
Here are the first few samples of data:
|
24 |
-
<Sample Data>
|
25 |
-
{data}
|
26 |
-
</Sample Data<>
|
27 |
-
'''
|
28 |
-
|
29 |
-
|
30 |
-
USER_INPUT = '''
|
31 |
-
Here are the first few samples of data:
|
32 |
-
<Sample Data>
|
33 |
-
{data}
|
34 |
-
</Sample Data<>
|
35 |
-
|
36 |
-
Here is the User Description:
|
37 |
-
<User Description>
|
38 |
-
{user_description}
|
39 |
-
</User Description>
|
40 |
-
'''
|
41 |
-
|
42 |
-
|
43 |
-
print('Connecting to DB...')
|
44 |
# Connect to DB
|
45 |
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
46 |
|
@@ -60,6 +36,12 @@ for model in models:
|
|
60 |
|
61 |
llm = ChatHuggingFace(llm=endpoint).bind_tools(tools=[], max_tokens=8192)
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
# Get Databases
|
64 |
def get_schemas():
|
65 |
schemas = conn.execute("""
|
@@ -84,28 +66,20 @@ def get_data_df(schema):
|
|
84 |
return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()
|
85 |
|
86 |
|
87 |
-
def
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
HumanMessage(content=user_prompt.format(data=df.head().to_json(orient='records'))),
|
92 |
-
]
|
93 |
-
return messages
|
94 |
-
|
95 |
-
def chat_template_user(system_prompt, user_prompt, user_description, df):
|
96 |
|
97 |
-
messages = [
|
98 |
-
SystemMessage(content=system_prompt),
|
99 |
-
HumanMessage(content=user_prompt.format(data=df.head(1).to_json(orient='records'), user_description=user_description)),
|
100 |
-
]
|
101 |
-
return messages
|
102 |
|
|
|
|
|
|
|
103 |
|
104 |
-
@traceable()
|
105 |
def run_llm(messages):
|
106 |
try:
|
107 |
response = llm.invoke(messages)
|
108 |
-
print(response.content)
|
109 |
tests = json.loads(response.content)
|
110 |
except Exception as e:
|
111 |
return e
|
@@ -199,11 +173,11 @@ def main(table):
|
|
199 |
df = get_data_df(schema)
|
200 |
df_statistics, df_alerts = statistics(df)
|
201 |
describe_num, describe_cat = describe(df)
|
202 |
-
|
203 |
-
messages =
|
204 |
-
|
205 |
tests = run_llm(messages)
|
206 |
print(tests)
|
|
|
207 |
if isinstance(tests, Exception):
|
208 |
tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
|
209 |
return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests, pd.DataFrame([])
|
@@ -219,11 +193,9 @@ def user_results(table, text_query):
|
|
219 |
schema = get_table_schema(table)
|
220 |
df = get_data_df(schema)
|
221 |
|
222 |
-
messages =
|
223 |
-
user_prompt=USER_INPUT, user_description=text_query,
|
224 |
-
df=df)
|
225 |
-
print(messages)
|
226 |
tests = run_llm(messages)
|
|
|
227 |
print(f'Generated Tests from user input: {tests}')
|
228 |
|
229 |
if isinstance(tests, Exception):
|
|
|
6 |
import pandera as pa
|
7 |
from pandera import Column
|
8 |
import ydata_profiling as pp
|
|
|
9 |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
|
|
10 |
from langsmith import traceable
|
11 |
+
from langchain import hub
|
12 |
import warnings
|
13 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
14 |
|
|
|
17 |
TAB_LINES = 8
|
18 |
# Load Token
|
19 |
md_token = os.getenv('MD_TOKEN')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
# Connect to DB
|
21 |
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
22 |
|
|
|
36 |
|
37 |
llm = ChatHuggingFace(llm=endpoint).bind_tools(tools=[], max_tokens=8192)
|
38 |
|
39 |
+
|
40 |
+
|
41 |
+
prompt_autogenerate = hub.pull("autogenerate-rules-testworkflow")
|
42 |
+
prompt_user_input = hub.pull("usergenerate-rules-testworkflow")
|
43 |
+
|
44 |
+
|
45 |
# Get Databases
|
46 |
def get_schemas():
|
47 |
schemas = conn.execute("""
|
|
|
66 |
return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()
|
67 |
|
68 |
|
69 |
+
def format_prompt(df):
|
70 |
+
return prompt_autogenerate.format_prompt(data=df.head().to_json(orient='records'))
|
71 |
+
def format_user_prompt(df, user_description):
|
72 |
+
return prompt_user_input.format_prompt(data=df.head(2).to_json(orient='records'), user_description=user_description)
|
|
|
|
|
|
|
|
|
|
|
73 |
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
+
def process_inputs(inputs) :
|
76 |
+
print(inputs)
|
77 |
+
return {'input_query': inputs['messages'].to_messages()[1]}
|
78 |
|
79 |
+
@traceable(process_inputs=process_inputs)
|
80 |
def run_llm(messages):
|
81 |
try:
|
82 |
response = llm.invoke(messages)
|
|
|
83 |
tests = json.loads(response.content)
|
84 |
except Exception as e:
|
85 |
return e
|
|
|
173 |
df = get_data_df(schema)
|
174 |
df_statistics, df_alerts = statistics(df)
|
175 |
describe_num, describe_cat = describe(df)
|
176 |
+
|
177 |
+
messages = format_prompt(df=df)
|
|
|
178 |
tests = run_llm(messages)
|
179 |
print(tests)
|
180 |
+
|
181 |
if isinstance(tests, Exception):
|
182 |
tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
|
183 |
return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests, pd.DataFrame([])
|
|
|
193 |
schema = get_table_schema(table)
|
194 |
df = get_data_df(schema)
|
195 |
|
196 |
+
messages = format_user_prompt(df=df, user_description=text_query)
|
|
|
|
|
|
|
197 |
tests = run_llm(messages)
|
198 |
+
|
199 |
print(f'Generated Tests from user input: {tests}')
|
200 |
|
201 |
if isinstance(tests, Exception):
|
prompt.py
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
PROMPT_PANDERA = """
|
2 |
-
You are a data quality engineer. Your role is to create deterministic rules to validate the quality of a dataset using **Pandera**.
|
3 |
-
You will be provided with the first few rows of data below that represents the dataset for which you need to create validation rules. Please note that this is only a sample of the data, and there may be additional rows and categorical columns that are not fully represented in the sample. Keep in mind that the sample may not cover all possible values, but the validation rules must handle all data in the dataset.
|
4 |
-
|
5 |
-
Follow this process:
|
6 |
-
|
7 |
-
1. **Observe the sample data.**
|
8 |
-
2. **For each column**, create a validation rule using Pandera syntax.
|
9 |
-
Here are the valid pandera check class methods DO NOT USE ANYOTHER METHODS OTHER THAN THE BELOW GIVEN METHODS:
|
10 |
-
DO NOT USE SINGLE backslashes \ BUT USE DOUBLE backslashes \\ IN PATTERN
|
11 |
-
USE CORRECT SYNTAX AS SHOWN GIVEN BELOW
|
12 |
-
[
|
13 |
-
'pa.Check.between(min_value, max_value, include_min=True, include_max=True, **kwargs)',
|
14 |
-
'pa.Check.eq(value, **kwargs)',
|
15 |
-
'pa.Check.equal_to(value, **kwargs)',
|
16 |
-
'pa.Check.ge(min_value, **kwargs)',
|
17 |
-
'pa.Check.greater_than(min_value, **kwargs)',
|
18 |
-
'pa.Check.greater_than_or_equal_to(min_value, **kwargs)',
|
19 |
-
'pa.Check.gt(min_value, **kwargs)',
|
20 |
-
'pa.Check.in_range(min_value, max_value, include_min=True, include_max=True, **kwargs)',
|
21 |
-
'pa.Check.isin(allowed_values, **kwargs)',
|
22 |
-
'pa.Check.le(max_value, **kwargs)',
|
23 |
-
'pa.Check.less_than(max_value, **kwargs)',
|
24 |
-
'pa.Check.less_than_or_equal_to(max_value, **kwargs)',
|
25 |
-
'pa.Check.lt(max_value, **kwargs)',
|
26 |
-
'pa.Check.ne(value, **kwargs)',
|
27 |
-
'pa.Check.not_equal_to(value, **kwargs)',
|
28 |
-
'pa.Check.notin(forbidden_values, **kwargs)',
|
29 |
-
'pa.Check.str_contains(pattern, **kwargs)',
|
30 |
-
'pa.Check.str_endswith(string, **kwargs)',
|
31 |
-
'pa.Check.str_length(min_value=None, max_value=None, **kwargs)',
|
32 |
-
'pa.Check.str_matches(pattern, **kwargs)',
|
33 |
-
'pa.Check.str_startswith(string, **kwargs)',
|
34 |
-
'pa.Check.unique_values_eq(values, **kwargs)'
|
35 |
-
]
|
36 |
-
|
37 |
-
|
38 |
-
3. Ensure that each rule specifies the expected data type and applies necessary checks such as:
|
39 |
-
name argument should be a valid column name. DO NOT USE ANYOTHER PANDERA
|
40 |
-
- **Data Type Validation** (e.g., `pa.Column(int, nullable=False, name="age")` ensures integers)
|
41 |
-
- **Non-null Check** (e.g., `pa.Column(str, nullable=False, name="name")` to ensure no nulls are allowed)
|
42 |
-
- **Unique Value Check** (e.g., `pa.Column(int, unique=True, name="ID")` for uniqueness)
|
43 |
-
- **Range or Bound Checks** (e.g., `pa.Column(float, checks=pa.Check.in_range(min_value=0, max_value=100), name="score")` for numerical ranges)
|
44 |
-
- **Allowed Value Checks** (e.g., `pa.Column(str, checks=pa.Check.isin([value1, value2]), name="gender")` to restrict values to a set)
|
45 |
-
- **Custom Validation Logic** using `pa.Column(int, checks=pa.Check(lambda x: x % 2 == 0), name="even_number")` with lambda functions (e.g., custom logic for even numbers or string patterns)
|
46 |
-
FOR DATETIME OR DATE COLUMN USE THE BELOW VALIDATION DO NOT CONISER IT AS INT OR FLOAT
|
47 |
-
- **DateTime or Date Validation** (e.g., `pa.Column(pa.dtypes.Timestamp, nullable=False), name="date_column")` to ensure dates or datetime)
|
48 |
-
|
49 |
-
For each column, provide a **column name**, **rule name** and a pandera_rule. Example structure (It should be list of dicts):
|
50 |
-
|
51 |
-
|
52 |
-
[
|
53 |
-
{
|
54 |
-
"column_name": "age",
|
55 |
-
"rule_name": "Ensure Column is Integer",
|
56 |
-
"pandera_rule": "Column(int, nullable=False, name='age')"
|
57 |
-
},
|
58 |
-
{
|
59 |
-
"column_name": "ID",
|
60 |
-
"rule_name": "Unique Identifier Check",
|
61 |
-
"pandera_rule": "Column(int, unique=True, name='ID')"
|
62 |
-
}
|
63 |
-
]
|
64 |
-
|
65 |
-
|
66 |
-
3 Repeat this process for max 5 columns in the dataset. If the data is less than 5 columns than include all columns. Group all the rules into a single JSON object and ensure that there is at least one validation rule for each column.
|
67 |
-
Return the final rules as a single JSON object, ensuring that each column is thoroughly validated based on the observations of the sample data.
|
68 |
-
|
69 |
-
DO NOT OUTPUT ANYTHING OR ANY EXPLAINATION OTHER THAN JSON OBJECT
|
70 |
-
"""
|
71 |
-
|
72 |
-
|
73 |
-
PANDERA_USER_INPUT_PROMPT = """
|
74 |
-
You are a data quality engineer. Your role is to assist the user in creating deterministic rules to validate the quality of a dataset using **Pandera**.
|
75 |
-
You will be provided with the first few rows of data below that represents the dataset for which you need to help the user create validation rules. Please note that this is only a sample of the data, and there may be additional rows and categorical columns that are not fully represented in the sample. Keep in mind that the sample may not cover all possible values, but the validation rules must handle all data in the dataset.
|
76 |
-
|
77 |
-
Follow this process:
|
78 |
-
|
79 |
-
1. **Observe the sample data.**
|
80 |
-
2. Observe description and create a valid check
|
81 |
-
|
82 |
-
Here are the valid **Pandera** Checks that you can use:
|
83 |
-
1. 'pa.Check.between(min_value, max_value, include_min=True, include_max=True, **kwargs)'
|
84 |
-
2. 'pa.Check.eq(value, **kwargs)' Checks if a value is equal to the specified value.
|
85 |
-
3. 'pa.Check.equal_to(value, **kwargs)' Alias for eq(). Checks if a value is equal to the specified value.
|
86 |
-
4. 'pa.Check.ge(min_value, **kwargs)' Checks if a value is greater than or equal to the specified minimum value.
|
87 |
-
5. 'pa.Check.greater_than(min_value, **kwargs)' Checks if a value is strictly greater than the specified minimum value.
|
88 |
-
6. 'pa.Check.greater_than_or_equal_to(min_value, **kwargs)' Checks if a value is greater than or equal to the specified minimum value.
|
89 |
-
7. 'pa.Check.gt(min_value, **kwargs)' Alias for greater_than(). Checks if a value is strictly greater than the specified minimum value.
|
90 |
-
8. 'pa.Check.in_range(min_value, max_value, include_min=True, include_max=True, **kwargs)' Checks if a value is within the specified range. By default, it's inclusive of both min and max values.
|
91 |
-
9. 'pa.Check.isin(allowed_values, **kwargs)' Checks if a value is in the set of allowed values.
|
92 |
-
10. 'pa.Check.le(max_value, **kwargs)' Checks if a value is less than or equal to the specified maximum value.
|
93 |
-
11. 'pa.Check.less_than(max_value, **kwargs)' ): Checks if a value is strictly less than the specified maximum value.
|
94 |
-
12. 'pa.Check.less_than_or_equal_to(max_value, **kwargs)' Checks if a value is less than or equal to the specified maximum value.
|
95 |
-
13. 'pa.Check.lt(max_value, **kwargs)' Checks if a value is strictly less than the specified maximum value.
|
96 |
-
14. 'pa.Check.ne(value, **kwargs)' Checks if a value is not equal to the specified value.
|
97 |
-
15. 'pa.Check.not_equal_to(value, **kwargs)' Checks if a value is not equal to the specified value.
|
98 |
-
16. 'pa.Check.notin(forbidden_values, **kwargs)' Checks if a value is not in the set of forbidden values.
|
99 |
-
17. 'pa.Check.str_contains(pattern, **kwargs)' Checks if a string contains the specified pattern.
|
100 |
-
18. 'pa.Check.str_endswith(string, **kwargs)' Checks if a string ends with the specified substring.
|
101 |
-
19. 'pa.Check.str_length(min_value=None, max_value=None, **kwargs)' Checks if the length of a string is within the specified range.
|
102 |
-
20. 'pa.Check.str_matches(pattern, **kwargs)' Checks if a string matches the specified regular expression pattern.
|
103 |
-
21. 'pa.Check.str_startswith(string, **kwargs)' Checks if a string starts with the specified substring.
|
104 |
-
22. 'pa.Check.unique_values_eq(values, **kwargs)' Checks if the unique values in a column are equal to the specified set of values.
|
105 |
-
23. 'pa.Check(lambda x: x )' with lambda functions for custom logic.
|
106 |
-
**ALWAY USE THE COMPLETE PANDERA SYNTAX**
|
107 |
-
|
108 |
-
3. For each column, generate a **column name**, **rule name**, and a **Pandera rule** based on the user’s description. Example structure (It should be list of dicts):
|
109 |
-
|
110 |
-
|
111 |
-
[
|
112 |
-
{
|
113 |
-
"column_name": "unique_key",
|
114 |
-
"rule_name": "Unique Identifiers",
|
115 |
-
"pandera_rule": "pa.Column(int, nullable=False, unique=True, name='unique_key')"
|
116 |
-
}
|
117 |
-
]
|
118 |
-
|
119 |
-
|
120 |
-
4. Repeat this process for a maximum of 5 columns or based on user input. Group all the rules into a single JSON object and return it.
|
121 |
-
IMPORTANT: You should only generate rules based on the user’s input for each column. Return the final rules as a single JSON object, ensuring that the user's instructions are reflected in the validations.
|
122 |
-
|
123 |
-
DO NOT RETURN ANYTHING OR ANY EXPLANATION OTHER THAN JSON """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ duckdb==1.1.1
|
|
7 |
langsmith==0.1.135
|
8 |
pandera==0.20.4
|
9 |
ydata-profiling==v4.11.0
|
10 |
-
langchain-core==0.3.12
|
|
|
|
7 |
langsmith==0.1.135
|
8 |
pandera==0.20.4
|
9 |
ydata-profiling==v4.11.0
|
10 |
+
langchain-core==0.3.12
|
11 |
+
langchain==0.3.4
|