Mustehson committed
Commit 99c2740 · 1 Parent(s): eab6d7f

Added Text to Validation

Files changed (2)
  1. app.py +80 -11
  2. prompt.py +53 -0
app.py CHANGED
@@ -7,7 +7,9 @@ import pandera as pa
 from pandera import Column
 import ydata_profiling as pp
 from huggingface_hub import InferenceClient
-from prompt import PROMPT_PANDERA
+from prompt import PROMPT_PANDERA, PANDERA_USER_INPUT_PROMPT
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
 
 
 # Height of the Tabs Text Area
@@ -17,6 +19,7 @@ md_token = os.getenv('MD_TOKEN')
 os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')
 
 
+
 INPUT_PROMPT = '''
 Here are the first few samples of data:
 <Sample Data>
@@ -25,6 +28,19 @@ Here are the first few samples of data:
 '''
 
 
+USER_INPUT = '''
+Here are the first few samples of data:
+<Sample Data>
+{data}
+</Sample Data>
+
+Here is the User Description:
+<User Description>
+{user_description}
+</User Description>
+'''
+
+
 print('Connecting to DB...')
 # Connect to DB
 conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
@@ -53,11 +69,23 @@ def get_data_df(schema):
     print('Getting Dataframe from the Database')
     return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()
 
-def run_llm(df):
-    messages=[
-        {"role": "system", "content": PROMPT_PANDERA},
-        {"role": "user", "content": INPUT_PROMPT.format(data=df.head().to_json(orient='records'))},
+
+def chat_template(system_prompt, user_prompt, df):
+
+    messages=[
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt.format(data=df.head().to_json(orient='records'))},
     ]
+    return messages
+
+def chat_template_user(system_prompt, user_prompt, user_description, df):
+
+    messages=[
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt.format(data=df.head(1).to_json(orient='records'), user_description=user_description)},
+    ]
+    return messages
+
+def run_llm(messages):
     try:
         response = client.chat_completion(messages, max_tokens=1024)
         print(response.choices[0].message.content)
@@ -149,8 +177,10 @@ def main(table):
     df = get_data_df(schema)
     df_statistics, df_alerts = statistics(df)
     describe_num, describe_cat = describe(df)
+
+    messages = chat_template(system_prompt=PROMPT_PANDERA, user_prompt=INPUT_PROMPT, df=df)
 
-    tests = run_llm(df)
+    tests = run_llm(messages)
     print(tests)
     if isinstance(tests, Exception):
         tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
@@ -162,10 +192,34 @@ def main(table):
 
     return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results
 
+def user_results(table, text_query):
+
+    schema = get_table_schema(table)
+    df = get_data_df(schema)
+
+    messages = chat_template_user(system_prompt=PANDERA_USER_INPUT_PROMPT,
+                                  user_prompt=USER_INPUT, user_description=text_query,
+                                  df=df)
+    print(messages)
+    tests = run_llm(messages)
+    print(f'Generated Tests from user input: {tests}')
+
+    if isinstance(tests, Exception):
+        tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
+        return tests, pd.DataFrame([])
+
+    tests_df = pd.DataFrame(tests)
+    tests_df.rename(columns={tests_df.columns[0]: 'Column', tests_df.columns[1]: 'Rule Name', tests_df.columns[2]: 'Rules'}, inplace=True)
+    pandera_results = validate_pandera(tests, df)
+    print('Validated Tests with Pandera')
+
+    return tests_df, pandera_results
+
+
 # Custom CSS styling
 custom_css = """
 .gradio-container {
     background-color: #f0f4f8;
 }
 .logo {
     max-width: 200px;
@@ -196,7 +250,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
         schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
         tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
         with gr.Row():
-            generate_query_button = gr.Button("Validate Data", variant="primary")
+            generate_result = gr.Button("Validate Data", variant="primary")
 
     with gr.Column(scale=2):
         with gr.Tabs():
@@ -220,11 +274,26 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
 
             with gr.Tab("Data"):
                 result_output = gr.DataFrame(label="Dataframe (10 Rows)", value=[], interactive=False)
-
+
+            with gr.Tab('Text to Validation'):
+                with gr.Row():
+                    query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter Text Query to Generate Validation e.g. Validate that the incident_zip column contains valid 5-digit ZIP codes.")
+                with gr.Row():
+                    with gr.Column():
+                        pass
+                    with gr.Column(scale=1, min_width=50):
+                        user_generate_result = gr.Button("Validate Data", variant="primary")
+
+                with gr.Row():
+                    with gr.Column():
+                        query_tests = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
+                    with gr.Column():
+                        query_result = gr.DataFrame(label="Validation Result", value=[], interactive=False)
+
     schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
-    generate_query_button.click(main, inputs=[tables_dropdown], outputs=[result_output, data_description, data_alerts, describe_cat, describe_num, tests_output, test_result_output])
-
-
+    generate_result.click(main, inputs=[tables_dropdown], outputs=[result_output, data_description, data_alerts, describe_cat, describe_num, tests_output, test_result_output])
+    user_generate_result.click(user_results, inputs=[tables_dropdown, query_input], outputs=[query_tests, query_result])
 
 if __name__ == "__main__":
     demo.launch(debug=True)
+
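
The new `user_results` path calls a `validate_pandera(tests, df)` helper that is not part of this diff. As context, here is a minimal, hypothetical sketch of how such a helper could apply the generated rules, assuming the LLM returns a list of objects with `column_name` and `pandera_rule` keys (as in the prompt's example) and that each `pandera_rule` string evaluates to a pandera `Column`:

```python
# Hypothetical sketch only -- validate_pandera() is not shown in this commit.
# Assumes `tests` is a list of dicts like:
#   {"column_name": "OS", "rule_name": "...", "pandera_rule": "Column(str, pa.Check.isin([...]), name='OS')"}
import pandas as pd
import pandera as pa
from pandera import Column

def validate_pandera(tests, df):
    results = []
    for rule in tests:
        col, rule_str = rule["column_name"], rule["pandera_rule"]
        try:
            # Evaluate the rule string into a pandera Column and wrap it in a one-column schema.
            column = eval(rule_str, {"pa": pa, "Column": Column})
            schema = pa.DataFrameSchema({col: column})
            schema.validate(df, lazy=True)
            results.append({"Column": col, "Status": "✅ Passed", "Details": ""})
        except pa.errors.SchemaErrors as err:
            # lazy=True collects every failing check instead of stopping at the first one.
            results.append({"Column": col, "Status": "❌ Failed",
                            "Details": f"{len(err.failure_cases)} failing checks"})
        except Exception as err:
            # The LLM may emit a malformed rule string; surface it instead of crashing the app.
            results.append({"Column": col, "Status": "⚠️ Invalid rule", "Details": str(err)})
    return pd.DataFrame(results)
```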
prompt.py CHANGED
@@ -66,3 +66,56 @@ Return the final rules as a single JSON object, ensuring that each column is tho
 
 DO NOT OUTPUT ANYTHING OR ANY EXPLANATION OTHER THAN JSON OBJECT
 """
+
+
+PANDERA_USER_INPUT_PROMPT = """
+You are a data quality engineer. Your role is to assist the user in creating deterministic rules to validate the quality of a dataset using **Pandera**.
+You will be provided with the first few rows of data below that represent the dataset for which you need to help the user create validation rules. Please note that this is only a sample of the data, and there may be additional rows and categorical values that are not fully represented in the sample. Keep in mind that the sample may not cover all possible values, but the validation rules must handle all data in the dataset.
+
+Follow this process:
+
+1. **Observe the sample data.**
+2. Observe the user description and create a valid Pandera check.
+
+Here are the valid **Pandera** Checks that you can use:
+1. 'pa.Check.between(min_value, max_value, include_min=True, include_max=True, **kwargs)'
+2. 'pa.Check.eq(value, **kwargs)' Checks if a value is equal to the specified value.
+3. 'pa.Check.equal_to(value, **kwargs)' Alias for eq(). Checks if a value is equal to the specified value.
+4. 'pa.Check.ge(min_value, **kwargs)' Checks if a value is greater than or equal to the specified minimum value.
+5. 'pa.Check.greater_than(min_value, **kwargs)' Checks if a value is strictly greater than the specified minimum value.
+6. 'pa.Check.greater_than_or_equal_to(min_value, **kwargs)' Checks if a value is greater than or equal to the specified minimum value.
+7. 'pa.Check.gt(min_value, **kwargs)' Alias for greater_than(). Checks if a value is strictly greater than the specified minimum value.
+8. 'pa.Check.in_range(min_value, max_value, include_min=True, include_max=True, **kwargs)' Checks if a value is within the specified range. By default, it's inclusive of both min and max values.
+9. 'pa.Check.isin(allowed_values, **kwargs)' Checks if a value is in the set of allowed values.
+10. 'pa.Check.le(max_value, **kwargs)' Checks if a value is less than or equal to the specified maximum value.
+11. 'pa.Check.less_than(max_value, **kwargs)' Checks if a value is strictly less than the specified maximum value.
+12. 'pa.Check.less_than_or_equal_to(max_value, **kwargs)' Checks if a value is less than or equal to the specified maximum value.
+13. 'pa.Check.lt(max_value, **kwargs)' Checks if a value is strictly less than the specified maximum value.
+14. 'pa.Check.ne(value, **kwargs)' Checks if a value is not equal to the specified value.
+15. 'pa.Check.not_equal_to(value, **kwargs)' Checks if a value is not equal to the specified value.
+16. 'pa.Check.notin(forbidden_values, **kwargs)' Checks if a value is not in the set of forbidden values.
+17. 'pa.Check.str_contains(pattern, **kwargs)' Checks if a string contains the specified pattern.
+18. 'pa.Check.str_endswith(string, **kwargs)' Checks if a string ends with the specified substring.
+19. 'pa.Check.str_length(min_value=None, max_value=None, **kwargs)' Checks if the length of a string is within the specified range.
+20. 'pa.Check.str_matches(pattern, **kwargs)' Checks if a string matches the specified regular expression pattern.
+21. 'pa.Check.str_startswith(string, **kwargs)' Checks if a string starts with the specified substring.
+22. 'pa.Check.unique_values_eq(values, **kwargs)' Checks if the unique values in a column are equal to the specified set of values.
+23. 'pa.Check(lambda x: x)' with lambda functions for custom logic.
+24. 'pa.Column(int, nullable=False, unique=True, name='column_name')' for unique values.
+**ALWAYS USE THE COMPLETE PANDERA SYNTAX**
+
+3. For each column, generate a **column name**, **rule name**, and a **Pandera rule** based on the user's description. Example structure:
+
+```json
+[
+  {
+    "column_name": "OS",
+    "rule_name": "Allowed Operating Systems",
+    "pandera_rule": "Column(str, pa.Check.isin(['macOS', 'Windows', 'Linux']), nullable=False, name='OS')"
+  }
+]
+```
+4. Repeat this process for a maximum of 5 columns or based on user input. Group all the rules into a single JSON object and return it.
+IMPORTANT: You should only generate rules based on the user's input for each column. Return the final rules as a single JSON object, ensuring that the user's instructions are reflected in the validations.
+
+DO NOT RETURN ANYTHING OR ANY EXPLANATION OTHER THAN JSON """
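
For illustration only (none of the following is in the repo): with the new prompt, the placeholder query from the 'Text to Validation' tab ("Validate that the incident_zip column contains valid 5-digit ZIP codes.") should yield a rule string along these lines, which can be sanity-checked directly with pandera:

```python
# Illustrative example of the kind of rule string the prompt asks the model to return,
# and a quick local check that it behaves as intended. Not part of the commit.
import pandas as pd
import pandera as pa
from pandera import Column

rule_str = "Column(str, pa.Check.str_matches(r'^\\d{5}$'), nullable=False, name='incident_zip')"
schema = pa.DataFrameSchema({"incident_zip": eval(rule_str, {"pa": pa, "Column": Column})})

good = pd.DataFrame({"incident_zip": ["10001", "94103"]})
bad = pd.DataFrame({"incident_zip": ["1001", "ABCDE"]})

schema.validate(good)          # passes: every value matches the 5-digit pattern
try:
    schema.validate(bad, lazy=True)
except pa.errors.SchemaErrors as err:
    print(err.failure_cases)   # lists the rows that violate the pattern
```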