"""Set up the local data stores used by the chatbot.

Downloads the Sakila SQLite sample database and builds a persistent ChromaDB
collection of AG News articles embedded with a SentenceTransformer model.
"""

import os
import tempfile
import urllib.request

import chromadb
from chromadb.utils import embedding_functions
from datasets import load_dataset

SAKILA_URL = "https://github.com/ivanceras/sakila/raw/master/sqlite-sakila-db/sakila.db"

# AG News integer labels -> human-readable topic names.
LABEL_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def download_sakila_db(path="./sakila.db", url=SAKILA_URL):
    """Download the Sakila SQLite database to *path* unless it already exists.

    The data is first written to a temporary file in the destination
    directory and only then moved into place with ``os.replace``, so an
    interrupted download never leaves a truncated database behind that a
    later run would mistake for a complete one.

    Args:
        path: Destination file for the database.
        url: Where to fetch the database from.
    """
    if os.path.exists(path):
        print("✓ Sakila database already exists")
        return

    print("Downloading Sakila database...")
    # Same directory as the target so os.replace is a same-filesystem rename.
    target_dir = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace tmp_path is gone; it only remains here
        # if the download or the move failed, in which case we discard it.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("✓ Sakila database downloaded")


def _make_title(text, limit=100):
    """Return *text* cut to *limit* characters, with "..." appended if it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def setup_agnews_chromadb(
    db_path="./chroma_agnews/",
    collection_name="ag_news",
    split="train[:500]",
    model_name="all-mpnet-base-v2",
    batch_size=1000,
):
    """Load AG News articles and store them, with embeddings, in ChromaDB.

    Any existing collection with the same name is dropped first, so re-running
    the script rebuilds it from scratch instead of colliding on document IDs.

    Args:
        db_path: Directory for the persistent ChromaDB client.
        collection_name: Name of the collection to (re)create.
        split: Hugging Face ``datasets`` split expression to load.
        model_name: SentenceTransformer model used to embed the documents.
        batch_size: Number of documents sent per ``collection.add`` call;
            keeps large splits under Chroma's maximum batch size.
    """
    print("\nLoading AG News dataset...")
    ds = load_dataset("fancyzhx/ag_news", split=split)
    print(f"✓ Loaded {len(ds)} articles")

    os.makedirs(db_path, exist_ok=True)
    client = chromadb.PersistentClient(path=db_path)

    try:
        client.delete_collection(collection_name)
    except Exception:
        # Expected on the first run when the collection does not exist yet;
        # the exception type for that case differs between chromadb releases.
        pass

    # Create collection with embedding function
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=model_name
    )
    collection = client.create_collection(
        name=collection_name,
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"},
    )

    # Build ids, documents and metadata in a single pass over the dataset.
    ids, documents, metadatas = [], [], []
    for i, item in enumerate(ds):
        text = item["text"]
        label = item["label"]
        ids.append(f"doc_{i}")
        documents.append(text)
        metadatas.append({
            "label": label,
            "label_text": LABEL_NAMES[label],
            "title": _make_title(text),
        })

    # Adding to ChromaDB (embeddings are computed by embedding_fn on add).
    print("Computing embeddings and adding to ChromaDB...")
    for start in range(0, len(ids), batch_size):
        end = start + batch_size
        collection.add(
            ids=ids[start:end],
            documents=documents[start:end],
            metadatas=metadatas[start:end],
        )
    print(f"✓ Added {len(ds)} articles to ChromaDB")


if __name__ == "__main__":
    print("=== Setting up databases ===\n")
    download_sakila_db()
    setup_agnews_chromadb()
    print("\n Setup complete! Run 'streamlit run chatbot.py'")