# NOTE: the following header is residue from the web page this file was copied
# from; it is not part of the program.
#   Spaces: Sleeping
#   File size: 2,164 Bytes
#   Revision: 42cabf2
#   (line-number gutter 1-70 omitted)
import os
import urllib.request
import chromadb
from chromadb.utils import embedding_functions
from datasets import load_dataset
def download_sakila_db(
    url="https://github.com/ivanceras/sakila/raw/master/sqlite-sakila-db/sakila.db",
    dest="./sakila.db",
):
    """Download the Sakila SQLite database unless it is already present.

    The file is fetched into a sibling ``<dest>.part`` file and only renamed
    to ``dest`` once the transfer has finished. An interrupted download
    therefore never leaves a truncated file at ``dest`` that a later run
    would mistake for a complete database.

    Args:
        url: Location to fetch the database from. Defaults to the Sakila
            SQLite file used by this project.
        dest: Path the database is stored at. Defaults to ``./sakila.db``.

    Returns:
        None.

    Raises:
        OSError: Propagated from ``urllib.request.urlretrieve`` (this includes
            ``urllib.error.URLError``) or from the final rename.
    """
    # NOTE(review): the "β" in the messages below looks like a mis-decoded
    # check mark; it is left untouched here - confirm against the original file.
    if os.path.exists(dest):
        print("β Sakila database already exists")
        return

    print("Downloading Sakila database...")
    partial = f"{dest}.part"
    try:
        urllib.request.urlretrieve(url, partial)
        # Publish the finished file in one step so `dest` is either absent
        # or complete.
        os.replace(partial, dest)
    finally:
        # After a successful rename the partial file no longer exists; after
        # a failure this removes whatever was written so far.
        if os.path.exists(partial):
            os.remove(partial)
    print("β Sakila database downloaded")
def _preview_title(text, limit=100):
    """Return *text* cut to *limit* characters, with "..." appended if it was cut."""
    if len(text) > limit:
        return text[:limit] + "..."
    return text


def setup_agnews_chromadb():
    """Load the first 500 AG News training articles into a ChromaDB collection.

    Steps:
      1. Load ``train[:500]`` of the ``fancyzhx/ag_news`` dataset.
      2. Open a persistent ChromaDB client at ``./chroma_agnews/``.
      3. Drop any existing ``ag_news`` collection and create a fresh one that
         uses the ``all-mpnet-base-v2`` sentence-transformer embedding
         function and ``hnsw:space`` set to ``cosine``.
      4. Add every article as a document with id ``doc_<index>`` and metadata
         holding the numeric label, its text name and a short title.

    No embeddings are passed to ``collection.add``; presumably the collection
    computes them with its embedding function (confirm against the chromadb
    version in use).

    Returns:
        None.

    Raises:
        KeyError: If an article carries a label outside 0-3.
    """
    print("\nLoading AG News dataset...")
    ds = load_dataset("fancyzhx/ag_news", split="train[:500]")
    print(f"β Loaded {len(ds)} articles")

    os.makedirs("./chroma_agnews/", exist_ok=True)
    client = chromadb.PersistentClient(path="./chroma_agnews/")

    # Best-effort reset: on a first run there is no collection to delete.
    # NOTE(review): the exception raised for a missing collection appears to
    # differ between chromadb releases, so this catches Exception rather than
    # one type - narrow it once the chromadb version is pinned. Unlike a bare
    # `except:`, this no longer swallows KeyboardInterrupt or SystemExit.
    try:
        client.delete_collection("ag_news")
    except Exception:
        pass

    # Create collection with embedding function
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-mpnet-base-v2"
    )
    collection = client.create_collection(
        name="ag_news",
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"},
    )

    # Label mapping
    label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    # Adding to ChromaDB
    print("Computing embeddings and adding to ChromaDB...")
    ids, documents, metadatas = [], [], []
    # Walk the dataset once and build the three parallel lists together.
    for index, item in enumerate(ds):
        text = item["text"]
        label = item["label"]
        ids.append(f"doc_{index}")
        documents.append(text)
        metadatas.append(
            {
                "label": label,
                "label_text": label_names[label],
                "title": _preview_title(text),
            }
        )

    collection.add(
        ids=ids,
        documents=documents,
        metadatas=metadatas,
    )
    print(f"β Added {len(ds)} articles to ChromaDB")
def main():
    """Run both setup steps in order: Sakila download, then AG News indexing."""
    print("=== Setting up databases ===\n")
    download_sakila_db()
    setup_agnews_chromadb()
    print("\n Setup complete! Run 'streamlit run chatbot.py'")


# Only run the setup when executed as a script, not when imported.
if __name__ == "__main__":
    main()