# Spaces: Runtime error
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "ell-ai==0.0.14",
#     "marimo",
#     "openai==1.53.0",
#     "polars==1.12.0",
#     "altair==5.4.1",
# ]
# ///
import marimo

# marimo version that generated this notebook file.
__generated_with = "0.9.20"
# The notebook application; each cell below is registered on it via `@app.cell`.
app = marimo.App(width="medium")
@app.cell
def __(mo):
    # Notebook title.
    mo.md(r"""# Generative UI Chatbot""")
    return
@app.cell
def __(mo):
    # Text field holding the location of the CSV dataset to analyze,
    # pre-filled with the scikit-learn Fish dataset on Hugging Face.
    _default_dataset = "hf://datasets/scikit-learn/Fish/Fish.csv"
    dataset_input = mo.ui.text(value=_default_dataset, full_width=True)
    return (dataset_input,)
@app.cell
def __(dataset_input, mo):
    # Intro text with the dataset text field embedded inline so it can be edited.
    mo.md(f"""
    This chatbot can answer questions about the following dataset: {dataset_input}
    """)
    return
@app.cell
def __(dataset_input, mo, pl):
    # Load the dataset named in the text field and report the outcome.
    # On failure, fall back to an empty DataFrame so downstream cells still
    # receive a `df`, and show the error as a danger callout.
    try:
        df = pl.read_csv(dataset_input.value)
    except Exception as e:  # broad on purpose: any load failure is shown to the user
        df = pl.DataFrame()
        mo.output.replace(
            mo.md(f"""**Error loading dataset**:\n\n{e}""").callout(kind="danger")
        )
    else:
        mo.output.replace(
            mo.md(f"Loaded dataset with {len(df)} rows and {len(df.columns)} columns.")
        )
    return (df,)
@app.cell
def __():
    # Shared imports; returning them makes them available to other cells.
    import os

    import marimo as mo
    import polars as pl

    return mo, os, pl
@app.cell
def __(mo, os):
    # Password field for the OpenAI API key, pre-filled from the
    # OPENAI_API_KEY environment variable when it is set (empty otherwise).
    api_key_input = mo.ui.text(
        label="OpenAI API Key",
        kind="password",
        value=os.environ.get("OPENAI_API_KEY") or "",
    )
    return (api_key_input,)
@app.cell
def __(api_key_input):
    # Display the API key field.
    api_key_input
    return
@app.cell
def __(api_key_input, mo):
    from openai import Client

    # Stop here, showing a notice, until an API key has been entered.
    mo.stop(not api_key_input.value, mo.md("_Missing API key_"))
    client = Client(api_key=api_key_input.value)
    return Client, client
@app.cell
def __(df, mo):
    import ell

    # Tools the LLM may call while analyzing `df`. `@ell.tool()` registers each
    # function as an ell tool; its signature and docstring describe the tool to
    # the model, so the docstrings below act as prompts.

    @ell.tool()
    def chart_data(x_encoding: str, y_encoding: str, color: str):
        """Generate an altair chart.

        Draws a scatter plot of the dataset with `x_encoding` on the x axis,
        `y_encoding` on the y axis, and points colored by `color`.
        Only use column names from the schema.
        """
        # Imported inside the tool so altair is loaded only when a chart is drawn.
        import altair as alt

        return (
            alt.Chart(df)
            .mark_circle()
            .encode(x=x_encoding, y=y_encoding, color=color)
            .properties(width=500)
        )

    @ell.tool()
    def filter_dataset(sql_query: str):
        """
        Filter a polars dataframe using SQL. Please only use fields from the schema.
        When referring to the table in SQL, call it 'data'.
        """
        # Run the model-written SQL against `df`, registered under the name 'data'.
        filtered = df.sql(sql_query, table_name="data")
        # Show the result as a read-only table labelled with the query that produced it.
        return mo.ui.table(
            filtered,
            label=f"```sql\n{sql_query}\n```",
            selection=None,
            show_column_summaries=False,
        )

    return chart_data, ell, filter_dataset
@app.cell
def __(chart_data, client, df, ell, filter_dataset, mo):
    # `@ell.complex` turns `analyze_dataset` into an LLM call that returns an
    # ell Message (with `.text` and `.tool_calls`), which `my_model` relies on.
    # Without it, the function returns a plain str and `.tool_calls` fails.
    # NOTE(review): model name is an assumption — any tool-calling OpenAI
    # chat model should work here; confirm the intended one.
    @ell.complex(model="gpt-4o", tools=[chart_data, filter_dataset], client=client)
    def analyze_dataset(prompt: str) -> str:
        """You are a data scientist that can analyze a dataset"""
        return f"I have a dataset with schema: {df.schema}. \n{prompt}"

    def my_model(messages):
        """Chat model callback for `mo.ui.chat`.

        Sends `messages` to the LLM. If the model requested a tool, runs the
        first tool call and returns its result (a chart or table) directly.
        Otherwise returns the model's text reply.
        """
        # NOTE(review): `messages` is presumably the full chat history from
        # mo.ui.chat; it is interpolated into the prompt via its string form.
        response = analyze_dataset(messages)
        if response.tool_calls:
            # Only the first tool call is executed; any others are ignored.
            return response.tool_calls[0]()
        return response.text

    mo.ui.chat(
        my_model,
        prompts=[
            "Can you chart two columns of your choosing?",
            "Can you find the min, max of all numeric fields?",
            "What is the sum of {{column}}?",
        ],
    )
    return analyze_dataset, my_model
if __name__ == "__main__":
    # Run the notebook's cells as a script when this file is executed directly.
    app.run()