|
import os |
|
|
|
import streamlit as st |
|
from st_app import launch_bot |
|
|
|
import nest_asyncio |
|
import asyncio |
|
import uuid |
|
|
|
import sqlite3 |
|
from datasets import load_dataset |
|
|
|
|
|
def _init_session_state():
    """Seed Streamlit session-state keys that are not yet present.

    Keys are initialised in a fixed order (``device_id`` then
    ``feedback_key``). A key that already exists in ``st.session_state``
    keeps its current value.
    """
    # Factories rather than plain values so a fresh UUID is only generated
    # when the key is actually missing.
    defaults = {
        'device_id': lambda: str(uuid.uuid4()),
        "feedback_key": lambda: 0,
    }
    for key, make_default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = make_default()


_init_session_state()
|
|
|
def _table_populated(cursor) -> bool:
    """Return True if table ``cfpb_complaints`` exists and holds at least one row."""
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='cfpb_complaints'"
    )
    if cursor.fetchone() is None:
        return False
    # An existing-but-empty table means a previous run created the table but
    # failed before inserting data (e.g. the dataset download raised). Treat it
    # as unpopulated so setup is retried instead of being skipped forever.
    cursor.execute("SELECT 1 FROM cfpb_complaints LIMIT 1")
    return cursor.fetchone() is not None


def setup_db():
    """Create and populate the local SQLite database if it has no data yet.

    Opens ``cfpb_database.db`` in the current working directory and shows a
    Streamlit spinner while working. If the ``cfpb_complaints`` table already
    has at least one row, nothing else happens. Otherwise the function:

    1. executes the statements in ``create_table.sql``;
    2. loads ``cfpb_complaints.csv`` from the ``vectara/cfpb-complaints``
       dataset via ``datasets.load_dataset``, passing the ``HF_TOKEN``
       environment variable (``None`` if unset) as the token;
    3. writes the ``train`` split into the ``cfpb_complaints`` table and commits.

    The connection is always closed, including when any step raises.
    """
    db_path = 'cfpb_database.db'
    conn = sqlite3.connect(db_path)
    # try/finally so a failure in the SQL script, download, or insert does not
    # leak the connection (the sqlite3 Connection context manager only
    # commits/rolls back; it does not close).
    try:
        cursor = conn.cursor()

        with st.spinner("Loading data... Please wait..."):
            if _table_populated(cursor):
                print("Database table already populated, skipping setup")
                return

            print("Populating database table")

            with open('create_table.sql', 'r') as sql_file:
                sql_script = sql_file.read()
            cursor.executescript(sql_script)

            hf_token = os.getenv('HF_TOKEN')
            df = load_dataset(
                "vectara/cfpb-complaints",
                data_files="cfpb_complaints.csv",
                token=hf_token,
            )['train'].to_pandas()

            # NOTE(review): if_exists='replace' makes pandas drop the table
            # created by create_table.sql and recreate it from the DataFrame's
            # dtypes, so that script's column definitions for this table do not
            # survive. Confirm whether if_exists='append' was intended.
            df.to_sql('cfpb_complaints', conn, if_exists='replace', index=False)

            conn.commit()
    finally:
        conn.close()
|
|
|
def main():
    """Entry point: configure the page, ensure the DB exists, run the bot.

    ``nest_asyncio.apply()`` patches asyncio to permit nested event loops.
    Presumably it is called so that ``asyncio.run`` works inside Streamlit's
    runtime (confirm).
    """
    st.set_page_config(page_title="CFPB Complaints Assistant", layout="wide")
    setup_db()

    nest_asyncio.apply()
    bot = launch_bot()
    asyncio.run(bot)


if __name__ == "__main__":
    main()
|
|