Spaces:
Sleeping
Sleeping
Mustehson
commited on
Commit
·
99c2740
1
Parent(s):
eab6d7f
Added Text to Validation
Browse files
app.py
CHANGED
@@ -7,7 +7,9 @@ import pandera as pa
|
|
7 |
from pandera import Column
|
8 |
import ydata_profiling as pp
|
9 |
from huggingface_hub import InferenceClient
|
10 |
-
from prompt import PROMPT_PANDERA
|
|
|
|
|
11 |
|
12 |
|
13 |
# Height of the Tabs Text Area
|
@@ -17,6 +19,7 @@ md_token = os.getenv('MD_TOKEN')
|
|
17 |
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')
|
18 |
|
19 |
|
|
|
20 |
INPUT_PROMPT = '''
|
21 |
Here is the frist few samples of data:
|
22 |
<Sample Data>
|
@@ -25,6 +28,19 @@ Here is the frist few samples of data:
|
|
25 |
'''
|
26 |
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
print('Connecting to DB...')
|
29 |
# Connect to DB
|
30 |
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
@@ -53,11 +69,23 @@ def get_data_df(schema):
|
|
53 |
print('Getting Dataframe from the Database')
|
54 |
return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
60 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
try:
|
62 |
response = client.chat_completion(messages, max_tokens=1024)
|
63 |
print(response.choices[0].message.content)
|
@@ -149,8 +177,10 @@ def main(table):
|
|
149 |
df = get_data_df(schema)
|
150 |
df_statistics, df_alerts = statistics(df)
|
151 |
describe_num, describe_cat = describe(df)
|
|
|
|
|
152 |
|
153 |
-
tests = run_llm(
|
154 |
print(tests)
|
155 |
if isinstance(tests, Exception):
|
156 |
tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
|
@@ -162,10 +192,34 @@ def main(table):
|
|
162 |
|
163 |
return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results
|
164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
# Custom CSS styling
|
166 |
custom_css = """
|
|
|
167 |
.gradio-container {
|
168 |
background-color: #f0f4f8;
|
|
|
169 |
}
|
170 |
.logo {
|
171 |
max-width: 200px;
|
@@ -196,7 +250,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
|
|
196 |
schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
|
197 |
tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
|
198 |
with gr.Row():
|
199 |
-
|
200 |
|
201 |
with gr.Column(scale=2):
|
202 |
with gr.Tabs():
|
@@ -220,11 +274,26 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
|
|
220 |
|
221 |
with gr.Tab("Data"):
|
222 |
result_output = gr.DataFrame(label="Dataframe (10 Rows)", value=[], interactive=False)
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
|
229 |
if __name__ == "__main__":
|
230 |
demo.launch(debug=True)
|
|
|
|
7 |
from pandera import Column
|
8 |
import ydata_profiling as pp
|
9 |
from huggingface_hub import InferenceClient
|
10 |
+
from prompt import PROMPT_PANDERA, PANDERA_USER_INPUT_PROMPT
|
11 |
+
import warnings
|
12 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
13 |
|
14 |
|
15 |
# Height of the Tabs Text Area
|
|
|
19 |
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')
|
20 |
|
21 |
|
22 |
+
|
23 |
INPUT_PROMPT = '''
|
24 |
Here is the frist few samples of data:
|
25 |
<Sample Data>
|
|
|
28 |
'''
|
29 |
|
30 |
|
31 |
+
USER_INPUT = '''
|
32 |
+
Here is the frist few samples of data:
|
33 |
+
<Sample Data>
|
34 |
+
{data}
|
35 |
+
</Sample Data<>
|
36 |
+
|
37 |
+
Here is the User Description:
|
38 |
+
<User Description>
|
39 |
+
{user_description}
|
40 |
+
</User Description>
|
41 |
+
'''
|
42 |
+
|
43 |
+
|
44 |
print('Connecting to DB...')
|
45 |
# Connect to DB
|
46 |
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
|
|
69 |
print('Getting Dataframe from the Database')
|
70 |
return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()
|
71 |
|
72 |
+
|
73 |
+
def chat_template(system_prompt, user_prompt, df):
|
74 |
+
|
75 |
+
messages=[
|
76 |
+
{"role": "system", "content": system_prompt},
|
77 |
+
{"role": "user", "content": user_prompt.format(data=df.head().to_json(orient='records'))},
|
78 |
]
|
79 |
+
return messages
|
80 |
+
def chat_template_user(system_prompt, user_prompt, user_description, df):
|
81 |
+
|
82 |
+
messages=[
|
83 |
+
{"role": "system", "content": system_prompt},
|
84 |
+
{"role": "user", "content": user_prompt.format(data=df.head(1).to_json(orient='records'), user_description=user_description)},
|
85 |
+
]
|
86 |
+
return messages
|
87 |
+
|
88 |
+
def run_llm(messages):
|
89 |
try:
|
90 |
response = client.chat_completion(messages, max_tokens=1024)
|
91 |
print(response.choices[0].message.content)
|
|
|
177 |
df = get_data_df(schema)
|
178 |
df_statistics, df_alerts = statistics(df)
|
179 |
describe_num, describe_cat = describe(df)
|
180 |
+
|
181 |
+
messages = chat_template(system_prompt=PROMPT_PANDERA, user_prompt=INPUT_PROMPT, df=df)
|
182 |
|
183 |
+
tests = run_llm(messages)
|
184 |
print(tests)
|
185 |
if isinstance(tests, Exception):
|
186 |
tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
|
|
|
192 |
|
193 |
return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results
|
194 |
|
195 |
+
def user_results(table, text_query):
|
196 |
+
|
197 |
+
schema = get_table_schema(table)
|
198 |
+
df = get_data_df(schema)
|
199 |
+
|
200 |
+
messages = chat_template_user(system_prompt=PANDERA_USER_INPUT_PROMPT,
|
201 |
+
user_prompt=USER_INPUT, user_description=text_query,
|
202 |
+
df=df)
|
203 |
+
print(messages)
|
204 |
+
tests = run_llm(messages)
|
205 |
+
print(f'Generated Tests from user input: {tests}')
|
206 |
+
|
207 |
+
if isinstance(tests, Exception):
|
208 |
+
tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
|
209 |
+
return tests, pd.DataFrame([])
|
210 |
+
|
211 |
+
tests_df = pd.DataFrame(tests)
|
212 |
+
tests_df.rename(columns={tests_df.columns[0]: 'Column', tests_df.columns[1]: 'Rule Name', tests_df.columns[2]: 'Rules' }, inplace=True)
|
213 |
+
pandera_results = validate_pandera(tests, df)
|
214 |
+
|
215 |
+
return tests_df, pandera_results
|
216 |
+
|
217 |
# Custom CSS styling
|
218 |
custom_css = """
|
219 |
+
print('Validated Tests with Pandera')
|
220 |
.gradio-container {
|
221 |
background-color: #f0f4f8;
|
222 |
+
|
223 |
}
|
224 |
.logo {
|
225 |
max-width: 200px;
|
|
|
250 |
schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
|
251 |
tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
|
252 |
with gr.Row():
|
253 |
+
generate_result = gr.Button("Validate Data", variant="primary")
|
254 |
|
255 |
with gr.Column(scale=2):
|
256 |
with gr.Tabs():
|
|
|
274 |
|
275 |
with gr.Tab("Data"):
|
276 |
result_output = gr.DataFrame(label="Dataframe (10 Rows)", value=[], interactive=False)
|
277 |
+
|
278 |
+
with gr.Tab('Text to Validation'):
|
279 |
+
with gr.Row():
|
280 |
+
query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter Text Query to Generate Validation e.g. Validate that the incident_zip column contains valid 5-digit ZIP codes.")
|
281 |
+
with gr.Row():
|
282 |
+
with gr.Column():
|
283 |
+
pass
|
284 |
+
with gr.Column(scale=1, min_width=50):
|
285 |
+
user_generate_result = gr.Button("Validate Data", variant="primary" )
|
286 |
+
|
287 |
+
with gr.Row():
|
288 |
+
with gr.Column():
|
289 |
+
query_tests = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
|
290 |
+
with gr.Column():
|
291 |
+
query_result = gr.DataFrame(label="Validation Result", value=[], interactive=False)
|
292 |
+
|
293 |
schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
|
294 |
+
generate_result.click(main, inputs=[tables_dropdown], outputs=[result_output, data_description, data_alerts, describe_cat, describe_num, tests_output, test_result_output])
|
295 |
+
user_generate_result.click(user_results, inputs=[tables_dropdown, query_input], outputs=[query_tests, query_result])
|
|
|
296 |
|
297 |
if __name__ == "__main__":
|
298 |
demo.launch(debug=True)
|
299 |
+
|
prompt.py
CHANGED
@@ -66,3 +66,56 @@ Return the final rules as a single JSON object, ensuring that each column is tho
|
|
66 |
|
67 |
DO NOT OUTPUT ANYTHING OR ANY EXPLAINATION OTHER THAN JSON OBJECT
|
68 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
DO NOT OUTPUT ANYTHING OR ANY EXPLAINATION OTHER THAN JSON OBJECT
|
68 |
"""
|
69 |
+
|
70 |
+
|
71 |
+
PANDERA_USER_INPUT_PROMPT = """
|
72 |
+
You are a data quality engineer. Your role is to assist the user in creating deterministic rules to validate the quality of a dataset using **Pandera**.
|
73 |
+
You will be provided with the first few rows of data below that represents the dataset for which you need to help the user create validation rules. Please note that this is only a sample of the data, and there may be additional rows and categorical columns that are not fully represented in the sample. Keep in mind that the sample may not cover all possible values, but the validation rules must handle all data in the dataset.
|
74 |
+
|
75 |
+
Follow this process:
|
76 |
+
|
77 |
+
1. **Observe the sample data.**
|
78 |
+
2. Observe description and create a valid pander check
|
79 |
+
|
80 |
+
Here are the valid **Pandera** Checks that you can use:
|
81 |
+
1. 'pa.Check.between(min_value, max_value, include_min=True, include_max=True, **kwargs)'
|
82 |
+
2. 'pa.Check.eq(value, **kwargs)' Checks if a value is equal to the specified value.
|
83 |
+
3. 'pa.Check.equal_to(value, **kwargs)' Alias for eq(). Checks if a value is equal to the specified value.
|
84 |
+
4. 'pa.Check.ge(min_value, **kwargs)' Checks if a value is greater than or equal to the specified minimum value.
|
85 |
+
5. 'pa.Check.greater_than(min_value, **kwargs)' Checks if a value is strictly greater than the specified minimum value.
|
86 |
+
6. 'pa.Check.greater_than_or_equal_to(min_value, **kwargs)' Checks if a value is greater than or equal to the specified minimum value.
|
87 |
+
7. 'pa.Check.gt(min_value, **kwargs)' Alias for greater_than(). Checks if a value is strictly greater than the specified minimum value.
|
88 |
+
8. 'pa.Check.in_range(min_value, max_value, include_min=True, include_max=True, **kwargs)' Checks if a value is within the specified range. By default, it's inclusive of both min and max values.
|
89 |
+
9. 'pa.Check.isin(allowed_values, **kwargs)' Checks if a value is in the set of allowed values.
|
90 |
+
10. 'pa.Check.le(max_value, **kwargs)' Checks if a value is less than or equal to the specified maximum value.
|
91 |
+
11. 'pa.Check.less_than(max_value, **kwargs)' ): Checks if a value is strictly less than the specified maximum value.
|
92 |
+
12. 'pa.Check.less_than_or_equal_to(max_value, **kwargs)' Checks if a value is less than or equal to the specified maximum value.
|
93 |
+
13. 'pa.Check.lt(max_value, **kwargs)' Checks if a value is strictly less than the specified maximum value.
|
94 |
+
14. 'pa.Check.ne(value, **kwargs)' Checks if a value is not equal to the specified value.
|
95 |
+
15. 'pa.Check.not_equal_to(value, **kwargs)' Checks if a value is not equal to the specified value.
|
96 |
+
16. 'pa.Check.notin(forbidden_values, **kwargs)' Checks if a value is not in the set of forbidden values.
|
97 |
+
17. 'pa.Check.str_contains(pattern, **kwargs)' Checks if a string contains the specified pattern.
|
98 |
+
18. 'pa.Check.str_endswith(string, **kwargs)' Checks if a string ends with the specified substring.
|
99 |
+
19. 'pa.Check.str_length(min_value=None, max_value=None, **kwargs)' Checks if the length of a string is within the specified range.
|
100 |
+
20. 'pa.Check.str_matches(pattern, **kwargs)' Checks if a string matches the specified regular expression pattern.
|
101 |
+
21. 'pa.Check.str_startswith(string, **kwargs)' Checks if a string starts with the specified substring.
|
102 |
+
22. 'pa.Check.unique_values_eq(values, **kwargs)' Checks if the unique values in a column are equal to the specified set of values.
|
103 |
+
23. 'pa.Check(lambda x: x )' with lambda functions for custom logic.
|
104 |
+
24. 'pa.Column(int, nullable=False, unique=True, name='column_name') For unqiue values
|
105 |
+
**ALWAY USE THE COMPLETE PANDERA SYNTAX
|
106 |
+
|
107 |
+
3. For each column, generate a **column name**, **rule name**, and a **Pandera rule** based on the user’s description. Example structure:
|
108 |
+
|
109 |
+
```json
|
110 |
+
[
|
111 |
+
{
|
112 |
+
"column_name": "OS",
|
113 |
+
"rule_name": "Allowed Operating Systems",
|
114 |
+
"pandera_rule": "Column(str, pa.Check.isin(['macOS', 'Windows', 'Linux']), nullable=False, name='OS')"
|
115 |
+
}
|
116 |
+
]
|
117 |
+
|
118 |
+
4. Repeat this process for a maximum of 5 columns or based on user input. Group all the rules into a single JSON object and return it.
|
119 |
+
IMPORTANT: You should only generate rules based on the user’s input for each column. Return the final rules as a single JSON object, ensuring that the user's instructions are reflected in the validations.
|
120 |
+
|
121 |
+
DO NOT RETURN ANYTHING OR ANY EXPLANATION OTHER THAN JSON """
|