import os
import json
import duckdb
import gradio as gr
import pandas as pd
import pandera as pa
from pandera import Column
import random
from dataprep.eda import compute
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from .utils import (
    format_num_stats,
    format_cat_stats,
    format_ov_stats,
    format_insights
)
from langsmith import traceable
from langchain import hub
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Height of the Tabs Text Area
TAB_LINES = 8

#----------CONNECT TO DATABASE----------
# NOTE(review): if MD_TOKEN is unset, the literal string "None" ends up in the
# connection string; consider failing early instead.
md_token = os.getenv('MD_TOKEN')
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
#---------------------------------------

#-------LOAD HUGGINGFACE-------
# Candidate models, tried in order. The first endpoint whose
# get_endpoint_info() call succeeds is used.
models = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Meta-Llama-3-70B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
]
model_loaded = False
endpoint = None
for model in models:
    try:
        endpoint = HuggingFaceEndpoint(repo_id=model, max_new_tokens=8192)
        # Probe the endpoint; only success/failure matters, not the returned info.
        endpoint.client.get_endpoint_info()
        model_loaded = True
        break
    except Exception as e:
        print(f"Error for model {model}: {e}")

if endpoint is None:
    # Previously this surfaced as a NameError on `endpoint` below.
    raise RuntimeError(f"Could not construct a HuggingFace endpoint for any of: {models}")
if not model_loaded:
    # Best-effort: keep the last endpoint that could be constructed, so the
    # app still starts; LLM calls will report their own errors.
    print("Warning: no model endpoint responded; falling back to the last constructed endpoint")

llm = ChatHuggingFace(llm=endpoint).bind(max_tokens=4096)
#---------------------------------------

#-----LOAD PROMPT FROM LANGCHAIN HUB-----
prompt_autogenerate = hub.pull("autogenerate-rules-testworkflow")
prompt_user_input = hub.pull("usergenerate-rules-testworkflow")

#--------------ALL UTILS----------------
# Get Databases
def get_schemas():
    """Return the names of all schemas on the connection.

    The system schemas ``information_schema`` and ``pg_catalog`` are excluded.
    """
    schemas = conn.execute("""
        SELECT DISTINCT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """).fetchall()
    return [item[0] for item in schemas]


# Get Tables
def get_tables_names(schema_name):
    """Return the names of the tables that belong to *schema_name*.

    The schema name is bound as a query parameter rather than interpolated
    into the SQL. It arrives from the UI and must not be able to alter the query.
    """
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [table[0] for table in tables]


# Update Tables
def update_table_names(schema_name):
    """Return a Gradio update that replaces a component's choices with the tables of *schema_name*."""
    tables = get_tables_names(schema_name)
    return gr.update(choices=tables)
Schema def get_table_schema(table): result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df() ddl_create = result.iloc[0,0] parent_database = result.iloc[0,1] schema_name = result.iloc[0,2] full_path = f"{parent_database}.{schema_name}.{table}" if schema_name != "main": old_path = f"{schema_name}.{table}" else: old_path = table ddl_create = ddl_create.replace(old_path, full_path) return full_path def get_data_df(schema): print('Getting Dataframe from the Database') return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df() <<<<<<< HEAD def calcualte_stats(df): indev_stats = [] cols = [] _df = df.copy() num_cols = _df.select_dtypes(include=['number'], exclude=['datetime']).columns cat_cols = _df.select_dtypes(include=['object'], exclude=['datetime']).columns _all_stats = compute(_df) all_stats = format_ov_stats(_all_stats['stats']) insights = format_insights(_all_stats['overview_insights']) for i, col in enumerate(random.sample(num_cols.tolist()+cat_cols.tolist(), 2)): _indv_data = compute(_df, col) if col in cat_cols: indev_data_cat = format_cat_stats(_indv_data["data"]) indev_stats.append(pd.DataFrame([indev_data_cat['Overview']], index=[f'{col}_stats']).T) elif col in num_cols: try: indev_data_num = format_num_stats(_indv_data["data"]) except: indev_data_num = format_cat_stats(_indv_data["data"]) indev_stats.append(pd.DataFrame([indev_data_num['Overview']], index=[f'{col}_stats']).T) return { "overall_stats": pd.DataFrame(all_stats[0], index=['Dataset Statistics']).T, "insights": insights, "stats_1": indev_stats[0], "stats_2": indev_stats[1] } def df_summary(df): summary = [] for column in df.columns: if pd.api.types.is_numeric_dtype(df[column]): summary.append({ "column": column, "max": df[column].max(), "min": df[column].min(), "count": df[column].count(), "nunique": df[column].nunique(), "dtype": str(df[column].dtype), "top": None }) elif pd.api.types.is_categorical_dtype(df[column]) or 
pd.api.types.is_object_dtype(df[column]): top_value = df[column].mode().iloc[0] if not df[column].mode().empty else None summary.append({ "column": column, "max": None, "min": None, "count": df[column].count(), "nunique": df[column].nunique(), "dtype": str(df[column].dtype), "top": top_value }) summary_df = pd.DataFrame(summary) return summary_df.reset_index(drop=True) ======= >>>>>>> parent of 7c2e7ac (Summary Added) def format_prompt(df): summary_df = pd.DataFrame({ "max": df.max(), "min": df.min(), "top": df.mode().iloc[0], "nunique": df.nunique(), "count": df.count(), "dtype": df.dtypes.astype(str) }).reset_index().rename(columns={"index": "column"}) return prompt_autogenerate.format_prompt(data=df.head().to_json(orient='records'), summary=summary_df.to_json(orient='records')) def format_user_prompt(df): return prompt_user_input.format_prompt(data=df.head().to_json(orient='records')) def process_inputs(inputs) : return {'input_query': inputs['messages'].to_messages()[1]} @traceable(process_inputs=process_inputs) def run_llm(messages): try: response = llm.invoke(messages) tests = json.loads(response.content) except Exception as e: return e return tests def validate_pandera(tests, df): validation_results = [] for test in tests: column_name = test['column_name'] try: rule = eval(test['pandera_rule']) validated_column = rule(df[[column_name]]) validation_results.append({ "Columns": column_name, "Result": "✅ Pass" }) except Exception as e: validation_results.append({ "Columns": column_name, "Result": f"❌ Fail - {str(e)}" }) return pd.DataFrame(validation_results) #--------------------------------------- # Main Function def main(table): schema = get_table_schema(table) df = get_data_df(schema) messages = format_prompt(df=df) tests = run_llm(messages) print(tests) stats = calcualte_stats(df) df_insights = stats['insights'] df_statistics = stats['overall_stats'] df_stat_1 = stats['stats_1'] df_stat_2 = stats['stats_2'] if isinstance(tests, Exception): tests = 
pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}]) return df.head(10), df_statistics, df_insights, df_stat_1, df_stat_2, tests, pd.DataFrame([]) tests_df = pd.DataFrame(tests) tests_df.rename(columns={tests_df.columns[0]: 'Column', tests_df.columns[1]: 'Rule Name', tests_df.columns[2]: 'Rules' }, inplace=True) pandera_results = validate_pandera(tests, df) return df.head(10), df_statistics, df_insights, df_stat_1, df_stat_2, tests_df, pandera_results def user_results(table, text_query): schema = get_table_schema(table) df = get_data_df(schema) messages = format_user_prompt(df=df, user_description=text_query) tests = run_llm(messages) print(f'Generated Tests from user input: {tests}') if isinstance(tests, Exception): tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}]) return tests, pd.DataFrame([]) tests_df = pd.DataFrame(tests) tests_df.rename(columns={tests_df.columns[0]: 'Column', tests_df.columns[1]: 'Rule Name', tests_df.columns[2]: 'Rules' }, inplace=True) pandera_results = validate_pandera(tests, df) return tests_df, pandera_results # Custom CSS styling custom_css = """ print('Validated Tests with Pandera') .gradio-container { background-color: #f0f4f8; } .logo { max-width: 200px; margin: 20px auto; display: block; } .gr-button { background-color: #4a90e2 !important; } .gr-button:hover { background-color: #3a7bc8 !important; } """ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo: gr.Image("logo.png", label=None, show_label=False, container=False, height=100) gr.Markdown("""