Mustehson committed on
Commit
6dda383
·
1 Parent(s): d14334a

Refactoring & Logs

Browse files
Files changed (3) hide show
  1. app.py +20 -48
  2. prompt.py +0 -123
  3. requirements.txt +2 -1
app.py CHANGED
@@ -6,10 +6,9 @@ import pandas as pd
6
  import pandera as pa
7
  from pandera import Column
8
  import ydata_profiling as pp
9
- from langchain_core.messages import HumanMessage, SystemMessage
10
  from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
11
- from prompt import PROMPT_PANDERA, PANDERA_USER_INPUT_PROMPT
12
  from langsmith import traceable
 
13
  import warnings
14
  warnings.filterwarnings("ignore", category=DeprecationWarning)
15
 
@@ -18,29 +17,6 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
18
  TAB_LINES = 8
19
  # Load Token
20
  md_token = os.getenv('MD_TOKEN')
21
-
22
- INPUT_PROMPT = '''
23
- Here are the first few samples of data:
24
- <Sample Data>
25
- {data}
26
- </Sample Data<>
27
- '''
28
-
29
-
30
- USER_INPUT = '''
31
- Here are the first few samples of data:
32
- <Sample Data>
33
- {data}
34
- </Sample Data<>
35
-
36
- Here is the User Description:
37
- <User Description>
38
- {user_description}
39
- </User Description>
40
- '''
41
-
42
-
43
- print('Connecting to DB...')
44
  # Connect to DB
45
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
46
 
@@ -60,6 +36,12 @@ for model in models:
60
 
61
  llm = ChatHuggingFace(llm=endpoint).bind_tools(tools=[], max_tokens=8192)
62
 
 
 
 
 
 
 
63
  # Get Databases
64
  def get_schemas():
65
  schemas = conn.execute("""
@@ -84,28 +66,20 @@ def get_data_df(schema):
84
  return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()
85
 
86
 
87
- def chat_template(system_prompt, user_prompt, df):
88
-
89
- messages = [
90
- SystemMessage(content=system_prompt),
91
- HumanMessage(content=user_prompt.format(data=df.head().to_json(orient='records'))),
92
- ]
93
- return messages
94
-
95
- def chat_template_user(system_prompt, user_prompt, user_description, df):
96
 
97
- messages = [
98
- SystemMessage(content=system_prompt),
99
- HumanMessage(content=user_prompt.format(data=df.head(1).to_json(orient='records'), user_description=user_description)),
100
- ]
101
- return messages
102
 
 
 
 
103
 
104
- @traceable()
105
  def run_llm(messages):
106
  try:
107
  response = llm.invoke(messages)
108
- print(response.content)
109
  tests = json.loads(response.content)
110
  except Exception as e:
111
  return e
@@ -199,11 +173,11 @@ def main(table):
199
  df = get_data_df(schema)
200
  df_statistics, df_alerts = statistics(df)
201
  describe_num, describe_cat = describe(df)
202
-
203
- messages = chat_template(system_prompt=PROMPT_PANDERA, user_prompt=INPUT_PROMPT, df=df)
204
-
205
  tests = run_llm(messages)
206
  print(tests)
 
207
  if isinstance(tests, Exception):
208
  tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
209
  return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests, pd.DataFrame([])
@@ -219,11 +193,9 @@ def user_results(table, text_query):
219
  schema = get_table_schema(table)
220
  df = get_data_df(schema)
221
 
222
- messages = chat_template_user(system_prompt=PANDERA_USER_INPUT_PROMPT,
223
- user_prompt=USER_INPUT, user_description=text_query,
224
- df=df)
225
- print(messages)
226
  tests = run_llm(messages)
 
227
  print(f'Generated Tests from user input: {tests}')
228
 
229
  if isinstance(tests, Exception):
 
6
  import pandera as pa
7
  from pandera import Column
8
  import ydata_profiling as pp
 
9
  from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
 
10
  from langsmith import traceable
11
+ from langchain import hub
12
  import warnings
13
  warnings.filterwarnings("ignore", category=DeprecationWarning)
14
 
 
17
  TAB_LINES = 8
18
  # Load Token
19
  md_token = os.getenv('MD_TOKEN')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Connect to DB
21
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
22
 
 
36
 
37
  llm = ChatHuggingFace(llm=endpoint).bind_tools(tools=[], max_tokens=8192)
38
 
39
+
40
+
41
+ prompt_autogenerate = hub.pull("autogenerate-rules-testworkflow")
42
+ prompt_user_input = hub.pull("usergenerate-rules-testworkflow")
43
+
44
+
45
  # Get Databases
46
  def get_schemas():
47
  schemas = conn.execute("""
 
66
  return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()
67
 
68
 
69
+ def format_prompt(df):
70
+ return prompt_autogenerate.format_prompt(data=df.head().to_json(orient='records'))
71
+ def format_user_prompt(df, user_description):
72
+ return prompt_user_input.format_prompt(data=df.head(2).to_json(orient='records'), user_description=user_description)
 
 
 
 
 
73
 
 
 
 
 
 
74
 
75
+ def process_inputs(inputs) :
76
+ print(inputs)
77
+ return {'input_query': inputs['messages'].to_messages()[1]}
78
 
79
+ @traceable(process_inputs=process_inputs)
80
  def run_llm(messages):
81
  try:
82
  response = llm.invoke(messages)
 
83
  tests = json.loads(response.content)
84
  except Exception as e:
85
  return e
 
173
  df = get_data_df(schema)
174
  df_statistics, df_alerts = statistics(df)
175
  describe_num, describe_cat = describe(df)
176
+
177
+ messages = format_prompt(df=df)
 
178
  tests = run_llm(messages)
179
  print(tests)
180
+
181
  if isinstance(tests, Exception):
182
  tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
183
  return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests, pd.DataFrame([])
 
193
  schema = get_table_schema(table)
194
  df = get_data_df(schema)
195
 
196
+ messages = format_user_prompt(df=df, user_description=text_query)
 
 
 
197
  tests = run_llm(messages)
198
+
199
  print(f'Generated Tests from user input: {tests}')
200
 
201
  if isinstance(tests, Exception):
prompt.py DELETED
@@ -1,123 +0,0 @@
1
- PROMPT_PANDERA = """
2
- You are a data quality engineer. Your role is to create deterministic rules to validate the quality of a dataset using **Pandera**.
3
- You will be provided with the first few rows of data below that represents the dataset for which you need to create validation rules. Please note that this is only a sample of the data, and there may be additional rows and categorical columns that are not fully represented in the sample. Keep in mind that the sample may not cover all possible values, but the validation rules must handle all data in the dataset.
4
-
5
- Follow this process:
6
-
7
- 1. **Observe the sample data.**
8
- 2. **For each column**, create a validation rule using Pandera syntax.
9
- Here are the valid pandera check class methods DO NOT USE ANYOTHER METHODS OTHER THAN THE BELOW GIVEN METHODS:
10
- DO NOT USE SINGLE backslashes \ BUT USE DOUBLE backslashes \\ IN PATTERN
11
- USE CORRECT SYNTAX AS SHOWN GIVEN BELOW
12
- [
13
- 'pa.Check.between(min_value, max_value, include_min=True, include_max=True, **kwargs)',
14
- 'pa.Check.eq(value, **kwargs)',
15
- 'pa.Check.equal_to(value, **kwargs)',
16
- 'pa.Check.ge(min_value, **kwargs)',
17
- 'pa.Check.greater_than(min_value, **kwargs)',
18
- 'pa.Check.greater_than_or_equal_to(min_value, **kwargs)',
19
- 'pa.Check.gt(min_value, **kwargs)',
20
- 'pa.Check.in_range(min_value, max_value, include_min=True, include_max=True, **kwargs)',
21
- 'pa.Check.isin(allowed_values, **kwargs)',
22
- 'pa.Check.le(max_value, **kwargs)',
23
- 'pa.Check.less_than(max_value, **kwargs)',
24
- 'pa.Check.less_than_or_equal_to(max_value, **kwargs)',
25
- 'pa.Check.lt(max_value, **kwargs)',
26
- 'pa.Check.ne(value, **kwargs)',
27
- 'pa.Check.not_equal_to(value, **kwargs)',
28
- 'pa.Check.notin(forbidden_values, **kwargs)',
29
- 'pa.Check.str_contains(pattern, **kwargs)',
30
- 'pa.Check.str_endswith(string, **kwargs)',
31
- 'pa.Check.str_length(min_value=None, max_value=None, **kwargs)',
32
- 'pa.Check.str_matches(pattern, **kwargs)',
33
- 'pa.Check.str_startswith(string, **kwargs)',
34
- 'pa.Check.unique_values_eq(values, **kwargs)'
35
- ]
36
-
37
-
38
- 3. Ensure that each rule specifies the expected data type and applies necessary checks such as:
39
- name argument should be a valid column name. DO NOT USE ANYOTHER PANDERA
40
- - **Data Type Validation** (e.g., `pa.Column(int, nullable=False, name="age")` ensures integers)
41
- - **Non-null Check** (e.g., `pa.Column(str, nullable=False, name="name")` to ensure no nulls are allowed)
42
- - **Unique Value Check** (e.g., `pa.Column(int, unique=True, name="ID")` for uniqueness)
43
- - **Range or Bound Checks** (e.g., `pa.Column(float, checks=pa.Check.in_range(min_value=0, max_value=100), name="score")` for numerical ranges)
44
- - **Allowed Value Checks** (e.g., `pa.Column(str, checks=pa.Check.isin([value1, value2]), name="gender")` to restrict values to a set)
45
- - **Custom Validation Logic** using `pa.Column(int, checks=pa.Check(lambda x: x % 2 == 0), name="even_number")` with lambda functions (e.g., custom logic for even numbers or string patterns)
46
- FOR DATETIME OR DATE COLUMN USE THE BELOW VALIDATION DO NOT CONISER IT AS INT OR FLOAT
47
- - **DateTime or Date Validation** (e.g., `pa.Column(pa.dtypes.Timestamp, nullable=False), name="date_column")` to ensure dates or datetime)
48
-
49
- For each column, provide a **column name**, **rule name** and a pandera_rule. Example structure (It should be list of dicts):
50
-
51
-
52
- [
53
- {
54
- "column_name": "age",
55
- "rule_name": "Ensure Column is Integer",
56
- "pandera_rule": "Column(int, nullable=False, name='age')"
57
- },
58
- {
59
- "column_name": "ID",
60
- "rule_name": "Unique Identifier Check",
61
- "pandera_rule": "Column(int, unique=True, name='ID')"
62
- }
63
- ]
64
-
65
-
66
- 3 Repeat this process for max 5 columns in the dataset. If the data is less than 5 columns than include all columns. Group all the rules into a single JSON object and ensure that there is at least one validation rule for each column.
67
- Return the final rules as a single JSON object, ensuring that each column is thoroughly validated based on the observations of the sample data.
68
-
69
- DO NOT OUTPUT ANYTHING OR ANY EXPLAINATION OTHER THAN JSON OBJECT
70
- """
71
-
72
-
73
- PANDERA_USER_INPUT_PROMPT = """
74
- You are a data quality engineer. Your role is to assist the user in creating deterministic rules to validate the quality of a dataset using **Pandera**.
75
- You will be provided with the first few rows of data below that represents the dataset for which you need to help the user create validation rules. Please note that this is only a sample of the data, and there may be additional rows and categorical columns that are not fully represented in the sample. Keep in mind that the sample may not cover all possible values, but the validation rules must handle all data in the dataset.
76
-
77
- Follow this process:
78
-
79
- 1. **Observe the sample data.**
80
- 2. Observe description and create a valid check
81
-
82
- Here are the valid **Pandera** Checks that you can use:
83
- 1. 'pa.Check.between(min_value, max_value, include_min=True, include_max=True, **kwargs)'
84
- 2. 'pa.Check.eq(value, **kwargs)' Checks if a value is equal to the specified value.
85
- 3. 'pa.Check.equal_to(value, **kwargs)' Alias for eq(). Checks if a value is equal to the specified value.
86
- 4. 'pa.Check.ge(min_value, **kwargs)' Checks if a value is greater than or equal to the specified minimum value.
87
- 5. 'pa.Check.greater_than(min_value, **kwargs)' Checks if a value is strictly greater than the specified minimum value.
88
- 6. 'pa.Check.greater_than_or_equal_to(min_value, **kwargs)' Checks if a value is greater than or equal to the specified minimum value.
89
- 7. 'pa.Check.gt(min_value, **kwargs)' Alias for greater_than(). Checks if a value is strictly greater than the specified minimum value.
90
- 8. 'pa.Check.in_range(min_value, max_value, include_min=True, include_max=True, **kwargs)' Checks if a value is within the specified range. By default, it's inclusive of both min and max values.
91
- 9. 'pa.Check.isin(allowed_values, **kwargs)' Checks if a value is in the set of allowed values.
92
- 10. 'pa.Check.le(max_value, **kwargs)' Checks if a value is less than or equal to the specified maximum value.
93
- 11. 'pa.Check.less_than(max_value, **kwargs)' ): Checks if a value is strictly less than the specified maximum value.
94
- 12. 'pa.Check.less_than_or_equal_to(max_value, **kwargs)' Checks if a value is less than or equal to the specified maximum value.
95
- 13. 'pa.Check.lt(max_value, **kwargs)' Checks if a value is strictly less than the specified maximum value.
96
- 14. 'pa.Check.ne(value, **kwargs)' Checks if a value is not equal to the specified value.
97
- 15. 'pa.Check.not_equal_to(value, **kwargs)' Checks if a value is not equal to the specified value.
98
- 16. 'pa.Check.notin(forbidden_values, **kwargs)' Checks if a value is not in the set of forbidden values.
99
- 17. 'pa.Check.str_contains(pattern, **kwargs)' Checks if a string contains the specified pattern.
100
- 18. 'pa.Check.str_endswith(string, **kwargs)' Checks if a string ends with the specified substring.
101
- 19. 'pa.Check.str_length(min_value=None, max_value=None, **kwargs)' Checks if the length of a string is within the specified range.
102
- 20. 'pa.Check.str_matches(pattern, **kwargs)' Checks if a string matches the specified regular expression pattern.
103
- 21. 'pa.Check.str_startswith(string, **kwargs)' Checks if a string starts with the specified substring.
104
- 22. 'pa.Check.unique_values_eq(values, **kwargs)' Checks if the unique values in a column are equal to the specified set of values.
105
- 23. 'pa.Check(lambda x: x )' with lambda functions for custom logic.
106
- **ALWAY USE THE COMPLETE PANDERA SYNTAX**
107
-
108
- 3. For each column, generate a **column name**, **rule name**, and a **Pandera rule** based on the user’s description. Example structure (It should be list of dicts):
109
-
110
-
111
- [
112
- {
113
- "column_name": "unique_key",
114
- "rule_name": "Unique Identifiers",
115
- "pandera_rule": "pa.Column(int, nullable=False, unique=True, name='unique_key')"
116
- }
117
- ]
118
-
119
-
120
- 4. Repeat this process for a maximum of 5 columns or based on user input. Group all the rules into a single JSON object and return it.
121
- IMPORTANT: You should only generate rules based on the user’s input for each column. Return the final rules as a single JSON object, ensuring that the user's instructions are reflected in the validations.
122
-
123
- DO NOT RETURN ANYTHING OR ANY EXPLANATION OTHER THAN JSON """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -7,4 +7,5 @@ duckdb==1.1.1
7
  langsmith==0.1.135
8
  pandera==0.20.4
9
  ydata-profiling==v4.11.0
10
- langchain-core==0.3.12
 
 
7
  langsmith==0.1.135
8
  pandera==0.20.4
9
  ydata-profiling==v4.11.0
10
+ langchain-core==0.3.12
11
+ langchain==0.3.4