# Page header captured together with this source ("Spaces: Runtime error");
# it is not part of the program.
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import altair as alt | |
import pydeck as pdk | |
import random | |
from pytz import country_names | |
from st_aggrid import AgGrid, GridUpdateMode, JsCode | |
from st_aggrid.grid_options_builder import GridOptionsBuilder | |
import snowflake.connector | |
from snowflake.connector.pandas_tools import write_pandas | |
from snowflake.connector import connect | |
def load_data():
    """Read ``country-list.csv`` (relative to the working directory).

    Returns:
        The file contents as a ``pandas.DataFrame`` parsed with the default
        ``read_csv`` settings.
    """
    return pd.read_csv("country-list.csv")
def convert_df(df):
    """Serialise ``df`` to CSV text (index column included) as UTF-8 bytes.

    No caching is applied inside this function: each call re-serialises the
    frame, so callers that run on every rerun pay that cost every time.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
def execute_query(conn, df_sel_row, table_name):
    """Recreate ``table_name`` in Snowflake and load the selected rows into it.

    Does nothing when ``df_sel_row`` is empty. Otherwise the table is created
    (or replaced) with ``COUNTRY``, ``CAPITAL`` and ``TYPE`` string columns and
    the frame is uploaded with ``write_pandas`` into ``STREAMLIT_DB.PUBLIC``.

    Args:
        conn: Open Snowflake connection.
        df_sel_row: DataFrame holding the rows to upload.
        table_name: Target table name; must be a plain unquoted identifier
            (letters, digits, ``_`` and ``$``, not starting with a digit).

    Raises:
        ValueError: If ``table_name`` is not a plain identifier.
    """
    import re

    if df_sel_row.empty:
        return

    # table_name is free text typed by the user and is interpolated straight
    # into the DDL below, so refuse anything that is not a bare identifier
    # instead of letting it reach the database as SQL.
    if not isinstance(table_name, str) or not re.fullmatch(
        r"[A-Za-z_][A-Za-z0-9_$]*", table_name
    ):
        raise ValueError(
            f"Invalid table name {table_name!r}: use only letters, digits, "
            "'_' or '$', and do not start with a digit."
        )

    cursor = conn.cursor()
    try:
        cursor.execute(
            "CREATE OR REPLACE TABLE "
            f"{table_name}(COUNTRY string, CAPITAL string, TYPE string)"
        )
    finally:
        # Release the cursor even when the DDL statement fails.
        cursor.close()

    write_pandas(
        conn=conn,
        df=df_sel_row,
        table_name=table_name,
        database="STREAMLIT_DB",
        schema="PUBLIC",
        quote_identifiers=False,
    )
# Initialize connection.
# Wrapped in st.experimental_singleton so the connection is opened once and
# reused, rather than a new one being opened on every script rerun.
@st.experimental_singleton
def init_connection():
    """Open a Snowflake connection from the ``snowflake`` section of ``st.secrets``.

    Returns:
        The connection object returned by ``snowflake.connector.connect``.
    """
    return snowflake.connector.connect(**st.secrets["snowflake"])
# The code below is for the title and logo.
# Sets the page title and icon, then shows the floppy-disk image at 100px width.
st.set_page_config(page_title="Dataframe with editable cells", page_icon="💾")
st.image(
    "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/240/apple/325/floppy-disk_1f4be.png",
    width=100,
)
# Open the Snowflake connection and load the dataframe that the grid displays.
conn = init_connection()
df = load_data()
st.title("Dataframe with editable cells")
# NOTE(review): the empty st.write("") / st.caption("") calls look like
# vertical spacers between sections - confirm before removing any of them.
st.write("")
# Introductory blurb, rendered as markdown.
st.markdown(
    """This is a demo of a dataframe with editable cells, powered by
[streamlit-aggrid](https://pypi.org/project/streamlit-aggrid/).
You can edit the cells by clicking on them, and then export
your selection to a `.csv` file (or send it to your Snowflake DB!)"""
)
st.write("")
st.write("")
# Step 1 heading plus a usage hint for multi-row selection.
st.subheader("① Select and edit cells")
st.info("💡 Hold the `Shift` (⇧) key to select multiple rows at once.")
st.caption("")
# Build the grid options from the dataframe: pagination enabled, every column
# editable and groupable by default, multi-row selection with checkboxes.
gd = GridOptionsBuilder.from_dataframe(df)
gd.configure_pagination(enabled=True)
gd.configure_default_column(editable=True, groupable=True)
gd.configure_selection(selection_mode="multiple", use_checkbox=True)
gridoptions = gd.build()
# Render the grid with the "material" theme.
# NOTE(review): update_mode=SELECTION_CHANGED presumably means results are
# refreshed when the row selection changes - confirm against st_aggrid docs.
grid_table = AgGrid(
    df,
    gridOptions=gridoptions,
    update_mode=GridUpdateMode.SELECTION_CHANGED,
    theme="material",
)
# The grid result's "selected_rows" entry feeds the steps that follow.
sel_row = grid_table["selected_rows"]
# Step 2: preview the selected rows and offer them as a CSV download.
st.subheader(" ② Check your selection")
st.write("")

df_sel_row = pd.DataFrame(sel_row)
csv = convert_df(df_sel_row)
has_selection = not df_sel_row.empty

# NOTE(review): the original indentation was lost; only the table preview is
# assumed to be guarded by the emptiness check, with the download button
# always rendered - confirm against the intended layout.
if has_selection:
    st.write(df_sel_row)

st.download_button(
    label="Download to CSV",
    data=csv,
    file_name="results.csv",
    mime="text/csv",
)
st.write("")
st.write("")

# Step 3: push the selection to Snowflake under a user-chosen table name.
st.subheader("③ Send to Snowflake DB ❄️")
st.write("")
table_name = st.text_input(
    "Pick a table name", "YOUR_TABLE_NAME_HERE", help="No spaces allowed"
)
# The upload itself happens in the on_click callback; the code below only
# reports the outcome on the run where the button was clicked.
run_query = st.button(
    "Add to DB",
    on_click=execute_query,
    args=(conn, df_sel_row, table_name),
)

if run_query:
    if has_selection:
        st.success(
            f"✔️ Selection added to the `{table_name}` table located in the `STREAMLIT_DB` database."
        )
        st.snow()
    else:
        st.info("Nothing to add to DB, please select some rows")
# Callback passed as on_change to the "qp" radio widget defined further down.
def update_params():
    """Copy the ``qp`` session-state value into the ``option`` query parameter."""
    chosen = st.session_state.qp
    st.experimental_set_query_params(option=chosen)
options = ["cat", "dog", "mouse", "bat", "duck"]
query_params = st.experimental_get_query_params()

# Pick the radio's initial index from the "option" query parameter, falling
# back to the first entry. The parameter may be absent (for example when the
# URL carries only unrelated parameters) or hold a value outside `options`;
# both cases keep the default instead of raising.
ix = 0
requested = query_params.get("option", [])
if requested and requested[0] in options:
    ix = options.index(requested[0])

selected_option = st.radio(
    "Param", options, index=ix, key="qp", on_change=update_params
)

# set query param based on selection
st.experimental_set_query_params(option=selected_option)

# display for debugging purposes
st.write('---', st.experimental_get_query_params())
# SETTING PAGE CONFIG TO WIDE MODE AND ADDING A TITLE AND FAVICON | |
#st.set_page_config(layout="wide", page_title="NYC Ridesharing Demo", page_icon=":taxi:") | |
# LOAD THE PICKUP DATA (no caching is applied here; each call re-reads the file)
def load_data():
    """Read the first 100000 rows of ``./uber-raw-data-sep14.csv.gz``.

    The file's own header row is skipped and the first three columns are
    named ``date/time``, ``lat`` and ``lon``; the remaining column is not
    loaded. ``date/time`` is parsed as datetimes while reading.

    Returns:
        A ``pandas.DataFrame`` with columns ``date/time``, ``lat``, ``lon``.
    """
    column_names = ["date/time", "lat", "lon"]
    return pd.read_csv(
        "./uber-raw-data-sep14.csv.gz",
        names=column_names,  # fixed names, so the header row is skipped below
        skiprows=1,
        usecols=[0, 1, 2],
        nrows=100000,
        parse_dates=["date/time"],
    )
# RENDER ONE HEXAGON MAP CENTRED ON A GIVEN POINT
# (the name shadows the builtin ``map``; kept because callers use it)
def map(data, lat, lon, zoom):
    """Write a pydeck hexagon-layer map of ``data`` to the page.

    Args:
        data: Rows to plot; positions are read from its ``lon`` and ``lat``
            fields.
        lat: Latitude of the initial view centre.
        lon: Longitude of the initial view centre.
        zoom: Initial zoom level.
    """
    view_state = {
        "latitude": lat,
        "longitude": lon,
        "zoom": zoom,
        "pitch": 50,
    }
    hexagons = pdk.Layer(
        "HexagonLayer",
        data=data,
        get_position=["lon", "lat"],
        radius=100,
        elevation_scale=4,
        elevation_range=[0, 1000],
        pickable=True,
        extruded=True,
    )
    deck = pdk.Deck(
        map_style="mapbox://styles/mapbox/light-v9",
        initial_view_state=view_state,
        layers=[hexagons],
    )
    st.write(deck)
# KEEP ONLY THE ROWS WHOSE PICKUP HOUR MATCHES THE SELECTION
def filterdata(df, hour_selected):
    """Return the rows of ``df`` whose ``date/time`` hour equals ``hour_selected``."""
    pickup_hour = df["date/time"].dt.hour
    return df[pickup_hour == hour_selected]
# AVERAGE LATITUDE / LONGITUDE OF A SET OF POINTS
def mpoint(lat, lon):
    """Return ``(average of lat, average of lon)`` as a 2-tuple."""
    mean_lat = np.average(lat)
    mean_lon = np.average(lon)
    return (mean_lat, mean_lon)
# COUNT PICKUPS PER MINUTE WITHIN ONE HOUR
def histdata(df, hr):
    """Build a per-minute pickup histogram for the hour starting at ``hr``.

    Args:
        df: DataFrame with a datetime ``date/time`` column.
        hr: Hour to summarise; rows with ``hr <= hour < hr + 1`` are counted.

    Returns:
        A 60-row DataFrame with columns ``minute`` (0-59) and ``pickups``
        (number of rows falling in that minute). All counts are zero when no
        row matches.
    """
    pickup_hour = df["date/time"].dt.hour
    in_hour = (pickup_hour >= hr) & (pickup_hour < (hr + 1))
    # Select from the frame that was passed in. The previous version indexed
    # the module-level ``data`` here, so the result ignored ``df`` and the
    # function failed when no such global existed.
    filtered = df[in_hour]
    counts = np.histogram(filtered["date/time"].dt.minute, bins=60, range=(0, 60))[0]
    return pd.DataFrame({"minute": range(60), "pickups": counts})
# STREAMLIT APP LAYOUT
# Load the pickup rows used by every map and chart below.
data = load_data()
# LAYING OUT THE TOP SECTION OF THE APP
# Two columns created with relative widths (2, 3): title + hour slider on the
# left, descriptive text on the right.
row1_1, row1_2 = st.columns((2, 3))
with row1_1:
    st.title("NYC Uber Ridesharing Data")
    # Slider over hours 0-23; its value drives the maps and the histogram.
    hour_selected = st.slider("Select hour of pickup", 0, 23)
with row1_2:
    # Intro text; the literal is passed to st.write exactly as written here.
    st.write(
        """
##
Examining how Uber pickups vary over time in New York City's and at its major regional airports.
By sliding the slider on the left you can view different slices of time and explore different transportation trends.
"""
    )
# LAYING OUT THE MIDDLE SECTION OF THE APP WITH THE MAPS
# Four columns with relative widths (2, 1, 1, 1): city-wide map, then one
# map per airport.
row2_1, row2_2, row2_3, row2_4 = st.columns((2, 1, 1, 1))

# SETTING THE ZOOM LOCATIONS FOR THE AIRPORTS ([latitude, longitude])
la_guardia = [40.7900, -73.8700]
jfk = [40.6650, -73.7821]
newark = [40.7090, -74.1805]
zoom_level = 12
midpoint = mpoint(data["lat"], data["lon"])

# All four maps show the same hour, so filter once and reuse the result
# instead of repeating the identical filter for every map.
hour_data = filterdata(data, hour_selected)

with row2_1:
    st.write(
        f"""**All New York City from {hour_selected}:00 and {(hour_selected + 1) % 24}:00**"""
    )
    map(hour_data, midpoint[0], midpoint[1], 11)

with row2_2:
    st.write("**La Guardia Airport**")
    map(hour_data, la_guardia[0], la_guardia[1], zoom_level)

with row2_3:
    st.write("**JFK Airport**")
    map(hour_data, jfk[0], jfk[1], zoom_level)

with row2_4:
    st.write("**Newark Airport**")
    map(hour_data, newark[0], newark[1], zoom_level)
# PER-MINUTE PICKUP COUNTS FOR THE SELECTED HOUR
chart_data = histdata(data, hour_selected)

# HISTOGRAM SECTION: heading first, then the area chart
st.write(
    f"""**Breakdown of rides per minute between {hour_selected}:00 and {(hour_selected + 1) % 24}:00**"""
)

# Stepped area chart of pickups against minute, with a tooltip on both fields.
minute_axis = alt.X("minute:Q", scale=alt.Scale(nice=False))
pickup_axis = alt.Y("pickups:Q")
pickup_chart = (
    alt.Chart(chart_data)
    .mark_area(interpolate="step-after")
    .encode(x=minute_axis, y=pickup_axis, tooltip=["minute", "pickups"])
    .configure_mark(opacity=0.2, color="red")
)
st.altair_chart(pickup_chart, use_container_width=True)
# foo must be wrapped by st.experimental_memo: the "Clear Foo" button below
# calls foo.clear(), which a plain function does not have and which would
# otherwise raise AttributeError.
@st.experimental_memo
def foo(x):
    """Return ``x`` squared (results are memoized by Streamlit)."""
    return x**2


if st.button("Clear Foo"):
    # Clear foo's memoized values:
    foo.clear()

if st.button("Clear All"):
    # Clear values from *all* memoized functions:
    st.experimental_memo.clear()