# File size: 2,164 Bytes
# 42cabf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import urllib.request
import chromadb
from chromadb.utils import embedding_functions
from datasets import load_dataset

def download_sakila_db(
    path="./sakila.db",
    url="https://github.com/ivanceras/sakila/raw/master/sqlite-sakila-db/sakila.db",
):
    """Download the Sakila SQLite database unless it is already present.

    The file is first fetched to ``<path>.part``. It is moved onto ``path``
    only after the transfer finished. An interrupted or failed download
    therefore never leaves a truncated database at ``path``, which later
    runs would otherwise mistake for a complete one and skip.

    Args:
        path: Destination file for the database (default ``./sakila.db``).
        url: Location to fetch the database from.

    Raises:
        OSError: (including ``urllib.error.URLError``) if the download or the
            final rename fails. Any partial ``.part`` file is removed before
            the error propagates.
    """
    if os.path.exists(path):
        print("✓ Sakila database already exists")
        return

    print("Downloading Sakila database...")
    tmp_path = os.fspath(path) + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # os.replace swaps the finished file in with a single rename, so
        # `path` is either absent or complete -- never half-written.
        os.replace(tmp_path, path)
    except BaseException:
        # Also covers Ctrl-C: drop the partial download, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("✓ Sakila database downloaded")

def setup_agnews_chromadb():
    """Load the first 500 AG News training articles into a fresh ChromaDB collection.

    Any existing ``ag_news`` collection under ``./chroma_agnews/`` is deleted.
    It is then recreated with cosine HNSW space and a SentenceTransformer
    ``all-mpnet-base-v2`` embedding function. Only raw documents are passed
    to ``collection.add``, so the embeddings are presumably computed by that
    embedding function during the add.

    Each document is stored with metadata:
        label:      the integer class id from the dataset.
        label_text: its name (World / Sports / Business / Sci/Tech).
        title:      the article text; texts longer than 100 characters are
                    cut to the first 100 characters plus "...".
    """
    print("\nLoading AG News dataset...")

    ds = load_dataset("fancyzhx/ag_news", split="train[:500]")
    print(f"✓ Loaded {len(ds)} articles")

    os.makedirs("./chroma_agnews/", exist_ok=True)
    client = chromadb.PersistentClient(path="./chroma_agnews/")

    # Start from a clean collection. Deleting a collection that does not
    # exist raises, and the exception type differs between chromadb
    # versions, so catch Exception. A bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit.
    try:
        client.delete_collection("ag_news")
    except Exception:
        pass

    # Create collection with embedding function
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-mpnet-base-v2"
    )

    collection = client.create_collection(
        name="ag_news",
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"}
    )

    # Label mapping
    label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    # Adding to ChromaDB
    print("Computing embeddings and adding to ChromaDB...")

    # Read each column once instead of walking the dataset row-by-row three
    # times (every row iteration re-formats a full record).
    texts = list(ds["text"])
    labels = list(ds["label"])

    ids = [f"doc_{i}" for i in range(len(texts))]
    metadatas = [
        {
            "label": label,
            "label_text": label_names[label],
            "title": (text[:100] + "...") if len(text) > 100 else text,
        }
        for text, label in zip(texts, labels)
    ]

    collection.add(
        ids=ids,
        documents=texts,
        metadatas=metadatas
    )

    print(f"✓ Added {len(ds)} articles to ChromaDB")

if __name__ == "__main__":
    # Script entry point: run every setup step in order, framed by a
    # start banner and a completion hint.
    setup_steps = (download_sakila_db, setup_agnews_chromadb)
    print("=== Setting up databases ===\n")
    for run_step in setup_steps:
        run_step()
    print("\n Setup complete! Run 'streamlit run chatbot.py'")