# Mustehson
# LLM Response
# b67984f
import os
import json
import duckdb
import gradio as gr
import pandas as pd
import pandera as pa
from pandera import Column
import ydata_profiling as pp
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langsmith import traceable
from langchain import hub
import warnings
import dlt
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Height of the Tabs Text Area
TAB_LINES = 8

#----------CONNECT TO DATABASE----------
# The MotherDuck token is read from the MD_TOKEN environment variable and
# embedded in the connection string; the connection is opened read-only.
md_token = os.environ.get('MD_TOKEN')
conn = duckdb.connect(
    "md:my_db?motherduck_token={}".format(md_token),
    read_only=True,
)
#---------------------------------------
#-------LOAD HUGGINGFACE-------
# Candidate chat models, tried in order. The first one whose endpoint can be
# constructed and answers the info probe is used.
models = ["Qwen/Qwen2.5-72B-Instruct","meta-llama/Meta-Llama-3-70B-Instruct",
          "meta-llama/Llama-3.1-70B-Instruct"]
model_loaded = False
endpoint = None
for model in models:
    try:
        endpoint = HuggingFaceEndpoint(repo_id=model, max_new_tokens=8192)
        # Probe the endpoint so an unreachable model is skipped here rather
        # than failing on the first user request.
        info = endpoint.client.get_endpoint_info()
        model_loaded = True
        break
    except Exception as e:
        print(f"Error for model {model}: {e}")

if endpoint is None:
    # Previously this case surfaced as a NameError on `endpoint` below.
    raise RuntimeError(
        f"Could not initialise a HuggingFace endpoint for any of: {models}"
    )
if not model_loaded:
    # Same fall-through as before: the last endpoint that could be constructed
    # is used even though its info probe failed.
    print("Warning: no model passed the endpoint probe; using the last constructed endpoint.")

llm = ChatHuggingFace(llm=endpoint).bind(max_tokens=8192)
#---------------------------------------
#-----LOAD PROMPTS FROM LANGCHAIN HUB-----
# Prompt used by format_prompt() to auto-generate validation rules from a data
# sample and its summary.
prompt_autogenerate = hub.pull("autogenerate-rules-testworkflow")
# Prompt used by format_user_prompt() to generate rules from a user's request.
prompt_user_input = hub.pull("usergenerate-rules-testworkflow")
#--------------ALL UTILS----------------
# Get Schemas
def get_schemas():
    """Return the distinct schema names visible on the shared connection.

    The built-in ``information_schema`` and ``pg_catalog`` schemas are
    excluded by the query itself.
    """
    cursor = conn.execute("""
SELECT DISTINCT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
""")
    # Each row is a one-element tuple holding the schema name.
    return [row[0] for row in cursor.fetchall()]
# Get Tables
def get_tables_names(schema_name):
    """Return the names of all tables that belong to *schema_name*.

    The schema name comes from a UI dropdown, so it is passed to DuckDB as a
    bound parameter instead of being interpolated into the SQL text.

    Args:
        schema_name: Schema to list tables for.

    Returns:
        A list of table-name strings (empty when the schema has no tables).
    """
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
# Update Tables
def update_table_names(schema_name):
    """Build a Gradio update that sets the table dropdown's choices.

    The choices are the tables returned by get_tables_names() for
    *schema_name*.
    """
    return gr.update(choices=get_tables_names(schema_name))
# def get_data_df(schema):
# print('Getting Dataframe from the Database')
# return conn.sql(f"SELECT * FROM {schema} LIMIT 1000")
@dlt.resource
def fetch_data(schema):
    """Yield the first 1000 rows of *schema* as a stream of DataFrame chunks.

    *schema* is interpolated into the query as the object to select from, so
    it must be a trusted, fully qualified table path.
    """
    relation = conn.sql(f"SELECT * FROM {schema} LIMIT 1000")
    # NOTE(review): the argument 2 is presumably DuckDB's vectors-per-chunk
    # setting — confirm against the DuckDB Python API.
    chunk_df = relation.fetch_df_chunk(2)
    # Stop on the first missing or empty chunk.
    while chunk_df is not None and len(chunk_df) != 0:
        yield chunk_df
        chunk_df = relation.fetch_df_chunk(2)
def create_pipeline(schema):
    """Copy a table sample into the local ``duckdb_pipeline`` dlt pipeline.

    Args:
        schema: Dotted path whose second and third components are used as the
            dataset name and table name respectively.

    Returns:
        ``"<dataset_name>.<table_name>"`` for the table that was loaded.
    """
    path_parts = schema.split('.')

    dataset_name = path_parts[1]
    print("Dataset Name: ", dataset_name)
    table_name = path_parts[2]
    print("Table Name: ", table_name)

    pipeline = dlt.pipeline(
        pipeline_name='duckdb_pipeline',
        destination='duckdb',
        dataset_name=dataset_name,
    )
    # "replace" overwrites any rows loaded into this table by an earlier run.
    load_info = pipeline.run(
        fetch_data(schema),
        table_name=table_name,
        write_disposition="replace",
    )
    print(load_info)

    return f"{dataset_name}.{table_name}"
def load_pipeline(table_name):
    """Open the local pipeline database and read up to 1000 rows of a table.

    Returns:
        A ``(connection, dataframe)`` tuple. The connection is left open and
        the caller is responsible for closing it.
    """
    local_conn = duckdb.connect("duckdb_pipeline.duckdb")
    sample_df = local_conn.sql(f"SELECT * FROM {table_name} LIMIT 1000").df()
    return local_conn, sample_df
def df_summary(df):
    """Summarise each numeric, categorical or object column of *df*.

    Numeric columns report ``max``/``min``; categorical and object columns
    report ``top`` (the first mode, or None when there is none). Columns of
    any other dtype are skipped.

    Args:
        df: DataFrame to summarise.

    Returns:
        A DataFrame with one row per summarised column and the fields
        ``column``, ``max``, ``min``, ``count``, ``nunique``, ``dtype``, ``top``.
    """
    summary = []
    for column in df.columns:
        series = df[column]
        if pd.api.types.is_numeric_dtype(series):
            col_max, col_min, top_value = series.max(), series.min(), None
        # isinstance() replaces the deprecated pd.api.types.is_categorical_dtype.
        elif isinstance(series.dtype, pd.CategoricalDtype) or pd.api.types.is_object_dtype(series):
            modes = series.mode()  # computed once instead of twice
            col_max, col_min = None, None
            top_value = modes.iloc[0] if not modes.empty else None
        else:
            continue
        summary.append({
            "column": column,
            "max": col_max,
            "min": col_min,
            "count": series.count(),
            "nunique": series.nunique(),
            "dtype": str(series.dtype),
            "top": top_value,
        })
    return pd.DataFrame(summary).reset_index(drop=True)
def format_prompt(df):
    """Fill the auto-generate prompt with a data sample and a column summary.

    ``data`` is the head of *df* and ``summary`` is the output of
    df_summary(df); both are serialised as JSON records.
    """
    summary_json = df_summary(df).to_json(orient='records')
    sample_json = df.head().to_json(orient='records')
    return prompt_autogenerate.format_prompt(data=sample_json, summary=summary_json)
def format_user_prompt(df, user_description=None):
    """Fill the user-driven prompt with a data sample and the user's request.

    user_results() calls this with ``user_description=...``, which the former
    one-argument signature rejected with a TypeError.

    Args:
        df: DataFrame whose head is sent to the model as JSON records.
        user_description: Free-text description of the validation the user
            wants. When None it is not passed to the prompt, which matches the
            old single-argument call.
            NOTE(review): presumably the hub prompt declares a
            ``user_description`` variable — confirm against the prompt.

    Returns:
        The formatted prompt value produced by ``prompt_user_input``.
    """
    prompt_kwargs = {"data": df.head().to_json(orient='records')}
    if user_description is not None:
        prompt_kwargs["user_description"] = user_description
    return prompt_user_input.format_prompt(**prompt_kwargs)
def process_inputs(inputs):
    """Reduce traced inputs to the second message of the prompt.

    ``inputs['messages']`` must expose ``to_messages()``; the element at
    index 1 of that list is returned under the ``input_query`` key.
    """
    message_list = inputs['messages'].to_messages()
    return {'input_query': message_list[1]}
@traceable(process_inputs=process_inputs)
def run_llm(messages):
    """Invoke the chat model and parse its reply as JSON.

    Only the surrounding Markdown code fence and a leading ``json`` language
    tag are stripped before parsing. The earlier blanket
    ``.replace("json", "")`` also removed the substring from inside the JSON
    payload itself, corrupting any value that contained it.

    Args:
        messages: Prompt value or messages passed to ``llm.invoke``.

    Returns:
        The decoded JSON on success. On any failure the exception object is
        returned, not raised; callers check for it with ``isinstance``.
    """
    try:
        response = llm.invoke(messages)
        print(response.content.replace("```", "'''").replace("json", ""))
        # Drop whitespace and backtick fences at both ends, then the language tag.
        payload = response.content.strip().strip("`").strip()
        if payload.startswith("json"):
            payload = payload[len("json"):]
        tests = json.loads(payload)
    except Exception as e:
        # Deliberate best-effort: the UI renders the error instead of crashing.
        return e
    return tests
# Get Schema
def get_table_schema(table):
    """Return the fully qualified ``database.schema.table`` path of *table*.

    The table name comes from the UI, so it is passed as a bound parameter.
    When several tables share the name, the first row returned by
    ``duckdb_tables()`` is used, as before.

    Args:
        table: Unqualified table name.

    Returns:
        The dotted path ``"<database>.<schema>.<table>"``.

    Raises:
        ValueError: If no table with that name exists (this was previously an
            IndexError raised from an empty result).
    """
    row = conn.execute(
        "SELECT database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} was not found in the connected database.")
    parent_database, schema_name = row
    return f"{parent_database}.{schema_name}.{table}"
def describe(df):
    """Return ``describe()`` tables for the numeric and object columns of *df*.

    Returns:
        ``(numerical_info, categorical_info)``. Each is the transposed
        ``describe()`` output with the column names in a ``column`` field, or
        an empty DataFrame when *df* has no columns of that kind.
    """
    def _describe_subset(subset):
        # No matching columns: hand back an empty frame instead of describing.
        if subset.shape[1] < 1:
            return pd.DataFrame()
        described = subset.describe().T.reset_index()
        return described.rename(columns={'index': 'column'})

    numerical_info = _describe_subset(df.select_dtypes(include=['number']))
    categorical_info = _describe_subset(df.select_dtypes(include=['object']))
    return numerical_info, categorical_info
def validate_pandera(tests, df):
    """Evaluate each generated rule against its column and report pass/fail.

    Args:
        tests: Iterable of dicts with a ``column_name`` and a ``pandera_rule``
            (a Python expression that evaluates to a callable).
        df: DataFrame holding the data to validate.

    Returns:
        A DataFrame with one row per test and the columns ``Columns`` and
        ``Result``. A rule that raises, or a malformed test entry, is reported
        as a failure instead of aborting the whole run.
    """
    validation_results = []
    for test in tests:
        column_name = "<unknown>"
        try:
            # Looked up inside the try so a missing key fails only this row.
            column_name = test['column_name']
            # SECURITY: eval() runs model-generated code with this module's
            # globals. Only use with a trusted model/prompt; consider a
            # restricted rule format instead.
            rule = eval(test['pandera_rule'])
            rule(df[[column_name]])
            outcome = "✅ Pass"
        except Exception as e:
            outcome = f"❌ Fail - {e}"
        validation_results.append({"Columns": column_name, "Result": outcome})
    return pd.DataFrame(validation_results)
def statistics(df):
    """Profile *df* and return its overview statistics and its alerts.

    Returns:
        ``(df_statistics, df_alerts)``: a two-column table of labelled
        statistics cast to int, and a table with one row per profiling alert.
    """
    report = pp.ProfileReport(df).get_description()
    table_stats = report.table
    alerts = report.alerts

    # Statistics: readable labels for known keys; unknown keys keep their name.
    labels = {
        'n': 'Number of observations',
        'n_var': 'Number of variables',
        'n_cells_missing': 'Number of cells missing',
        'n_vars_with_missing': 'Number of columns with missing data',
        'n_vars_all_missing': 'Columns with all missing data',
        'p_cells_missing': 'Missing cells (%)',
        'n_duplicates': 'Duplicated rows',
        'p_duplicates': 'Duplicated rows (%)',
    }
    stats = {}
    for key, value in table_stats.items():
        if key == 'types':
            continue
        stats[labels.get(key, key)] = value

    # Flatten the per-type column counts into their own rows, in fixed order.
    type_counts = table_stats.get('types', {})
    type_labels = (
        ('Text', 'Number of text columns'),
        ('Categorical', 'Number of categorical columns'),
        ('Numeric', 'Number of numeric columns'),
        ('DateTime', 'Number of datetime columns'),
    )
    for type_name, label in type_labels:
        if type_name in type_counts:
            stats[label] = type_counts[type_name]

    df_statistics = pd.DataFrame(list(stats.items()), columns=['Statistic Description', 'Value'])
    # NOTE(review): astype(int) truncates any non-integer statistic (e.g.
    # ratio-valued entries) — confirm this is intended.
    df_statistics['Value'] = df_statistics['Value'].astype(int)

    # Alerts: the alert text without square brackets, plus its type name.
    alert_rows = []
    for alert in alerts:
        message = str(alert).replace('[', '').replace(']', '')
        alert_rows.append((message, alert.alert_type_name))
    df_alerts = pd.DataFrame(alert_rows, columns=['Data Quality Issue', 'Category'])

    return df_statistics, df_alerts
#---------------------------------------
# Main Function
def main(table):
    """Profile a table, generate validation rules with the LLM and run them.

    Args:
        table: Unqualified name of the table selected in the UI.

    Returns:
        A 7-tuple matching the Gradio outputs: data preview (10 rows),
        statistics, alerts, categorical description, numerical description,
        generated rules (or a one-row error table) and validation results
        (empty when rule generation failed).
    """
    schema = get_table_schema(table)
    # Create dlt pipeline
    table_name = create_pipeline(schema)
    # Load dlt pipeline
    connection, df = load_pipeline(table_name)
    try:
        df_statistics, df_alerts = statistics(df)
        describe_num, describe_cat = describe(df)

        messages = format_prompt(df=df)
        tests = run_llm(messages)
        # run_llm returns the exception object instead of raising it.
        if isinstance(tests, Exception):
            tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
            return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests, pd.DataFrame([])

        tests_df = pd.DataFrame(tests)
        tests_df.rename(columns={tests_df.columns[0]: 'Column', tests_df.columns[1]: 'Rule Name', tests_df.columns[2]: 'Rules' }, inplace=True)
        pandera_results = validate_pandera(tests, df)
        return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results
    finally:
        # Close on every path; the early error return and any exception used
        # to leave the local DuckDB connection open.
        connection.close()
def user_results(table, text_query):
    """Generate validation rules from a user's text request and run them.

    Args:
        table: Unqualified name of the table selected in the UI.
        text_query: The user's free-text description of the validation wanted.

    Returns:
        ``(rules, results)``: the generated rules (or a one-row error table)
        and the validation results (empty when rule generation failed).
    """
    schema = get_table_schema(table)
    # Create dlt pipeline
    table_name = create_pipeline(schema)
    # Load dlt pipeline
    connection, df = load_pipeline(table_name)
    try:
        messages = format_user_prompt(df=df, user_description=text_query)
        # This call was missing, so `tests` was read before assignment below.
        tests = run_llm(messages)
        print(f'Generated Tests from user input: {tests}')
        # run_llm returns the exception object instead of raising it.
        if isinstance(tests, Exception):
            tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
            return tests, pd.DataFrame([])

        tests_df = pd.DataFrame(tests)
        tests_df.rename(columns={tests_df.columns[0]: 'Column', tests_df.columns[1]: 'Rule Name', tests_df.columns[2]: 'Rules' }, inplace=True)
        pandera_results = validate_pandera(tests, df)
        return tests_df, pandera_results
    finally:
        # Close on every path, including the early error return.
        connection.close()
# Custom CSS styling
# A stray Python print(...) line used to sit at the top of this string, where
# it became part of the first rule's selector. It has been removed; the CSS
# rules themselves are unchanged.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# Gradio UI. Left column: schema/table selection and the "Validate Data"
# button. Right column: tabs for the profiling results, the generated rules,
# the data preview and the free-text ("Text to Validation") workflow.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    # Page title and subtitle.
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Dataset Test Workflow</strong>
<br>
<span style='font-size: 20px;'>Implement and Automate Data Validation Processes.</span>
</div>
""")
    with gr.Row():
        # Controls: pick a schema, then one of its tables, then validate.
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            # Starts empty; filled by update_table_names when a schema is chosen.
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
            with gr.Row():
                generate_result = gr.Button("Validate Data", variant="primary")
        # Results area.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Description"):
                    with gr.Row():
                        with gr.Column():
                            data_description = gr.DataFrame(label="Data Description", value=[], interactive=False)
                    with gr.Row():
                        with gr.Column():
                            describe_cat = gr.DataFrame(label="Categorical Information", value=[], interactive=False)
                        with gr.Column():
                            describe_num = gr.DataFrame(label="Numerical Information", value=[], interactive=False)
                with gr.Tab("Alerts"):
                    data_alerts = gr.DataFrame(label="Alerts", value=[], interactive=False)
                with gr.Tab("Rules & Validations"):
                    tests_output = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                    test_result_output = gr.DataFrame(label="Validation Result", value=[], interactive=False)
                with gr.Tab("Data"):
                    result_output = gr.DataFrame(label="Dataframe (10 Rows)", value=[], interactive=False)
                with gr.Tab('Text to Validation'):
                    with gr.Row():
                        query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter Text Query to Generate Validation e.g. Validate that the incident_zip column contains valid 5-digit ZIP codes.")
                    with gr.Row():
                        # Empty column that sits to the left of the button.
                        with gr.Column():
                            pass
                        with gr.Column(scale=1, min_width=50):
                            user_generate_result = gr.Button("Validate Data", variant="primary" )
                    with gr.Row():
                        with gr.Column():
                            query_tests = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                        with gr.Column():
                            query_result = gr.DataFrame(label="Validation Result", value=[], interactive=False)
    # Event wiring. The order of each `outputs` list matches the order of the
    # values returned by main() and user_results().
    schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_result.click(main, inputs=[tables_dropdown], outputs=[result_output, data_description, data_alerts, describe_cat, describe_num, tests_output, test_result_output])
    user_generate_result.click(user_results, inputs=[tables_dropdown, query_input], outputs=[query_tests, query_result])
# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)