Spaces:
Sleeping
Sleeping
Bharath Gajula
committed on
Commit
·
42cabf2
1
Parent(s):
d0b0e7b
sadas
Browse files- Dockerfile +4 -3
- README.md +20 -17
- agents/__init__.py +6 -0
- agents/__pycache__/__init__.cpython-312.pyc +0 -0
- agents/__pycache__/rag_agent.cpython-312.pyc +0 -0
- agents/__pycache__/sql_agent.cpython-312.pyc +0 -0
- agents/rag_agent.py +116 -0
- agents/sql_agent.py +30 -0
- app.py +110 -0
- chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/data_level0.bin +3 -0
- chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/header.bin +3 -0
- chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/length.bin +3 -0
- chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/link_lists.bin +0 -0
- chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/data_level0.bin +3 -0
- chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/header.bin +3 -0
- chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/length.bin +3 -0
- chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/link_lists.bin +0 -0
- requirements.txt +14 -3
- setup_data.py +70 -0
Dockerfile
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
FROM python:3.
|
2 |
|
3 |
WORKDIR /app
|
4 |
|
@@ -10,12 +10,13 @@ RUN apt-get update && apt-get install -y \
|
|
10 |
&& rm -rf /var/lib/apt/lists/*
|
11 |
|
12 |
COPY requirements.txt ./
|
13 |
-
COPY
|
14 |
|
15 |
RUN pip3 install -r requirements.txt
|
|
|
16 |
|
17 |
EXPOSE 8501
|
18 |
|
19 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
20 |
|
21 |
-
ENTRYPOINT ["streamlit", "run", "
|
|
|
1 |
+
FROM python:3.12-slim
|
2 |
|
3 |
WORKDIR /app
|
4 |
|
|
|
10 |
&& rm -rf /var/lib/apt/lists/*
|
11 |
|
12 |
COPY requirements.txt ./
|
13 |
+
COPY . .
|
14 |
|
15 |
RUN pip3 install -r requirements.txt
|
16 |
+
RUN python setup_data.py
|
17 |
|
18 |
EXPOSE 8501
|
19 |
|
20 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
21 |
|
22 |
+
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
README.md
CHANGED
@@ -1,20 +1,23 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: red
|
6 |
-
sdk: docker
|
7 |
-
app_port: 8501
|
8 |
-
tags:
|
9 |
-
- streamlit
|
10 |
-
pinned: false
|
11 |
-
short_description: Streamlit template space
|
12 |
-
license: apache-2.0
|
13 |
-
---
|
14 |
|
15 |
-
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
|
|
1 |
+
# Hybrid Search Chatbot
|
2 |
+
|
3 |
+
A Streamlit app for hybrid search: SQL (Sakila DB) and semantic (RAG, AG News/ChromaDB).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
## Quick Start
|
6 |
|
7 |
+
1. Install dependencies:
|
8 |
+
```bash
|
9 |
+
pip install -r requirements.txt
|
10 |
+
```
|
11 |
+
2. Initialize data:
|
12 |
+
```bash
|
13 |
+
python setup_data.py
|
14 |
+
```
|
15 |
+
3. Run the app:
|
16 |
+
```bash
|
17 |
+
streamlit run app.py
|
18 |
+
```
|
19 |
+
|
20 |
+
---
|
21 |
|
22 |
+
- Edit `.env` for API keys if needed.
|
23 |
+
- See `requirements.txt` for dependencies.
|
agents/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Agent package: exposes the SQL (Sakila) and RAG (AG News) agents."""

from .sql_agent import SQLAgent
from .rag_agent import RAGAgent

# Fixed: __all__ previously listed the module names ('sql_agent',
# 'rag_agent'), but the names actually imported and re-exported from this
# package are the classes. `from agents import *` would have raised
# AttributeError on the old entries.
__all__ = ['SQLAgent', 'RAGAgent']
__version__ = "1.0.0"
__author__ = "Bharath Gajula"
|
agents/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (346 Bytes). View file
|
|
agents/__pycache__/rag_agent.cpython-312.pyc
ADDED
Binary file (4.78 kB). View file
|
|
agents/__pycache__/sql_agent.cpython-312.pyc
ADDED
Binary file (1.62 kB). View file
|
|
agents/rag_agent.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chromadb
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
from typing import List, Dict
|
4 |
+
import os
|
5 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
6 |
+
from langchain.schema import HumanMessage
|
7 |
+
|
8 |
+
class RAGAgent:
    """Retrieval-augmented QA agent over the persisted AG News ChromaDB store.

    Embeds the user query with a SentenceTransformer, retrieves the
    nearest article chunks from the ``ag_news`` collection, and asks
    Gemini to answer the question from that retrieved context.
    """

    def __init__(self):
        # Query embedder — must be the same model used to build the
        # collection (setup_data.py also uses 'all-mpnet-base-v2'),
        # otherwise distances are meaningless.
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        # Answer generator; reads the API key from the environment
        # (loaded via dotenv in app.py).
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.3,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )

        persist_directory = "./chroma_agnews/"
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)

        # get_collection raises if the collection is absent — run
        # setup_data.py first to build it.
        self.collection = self.chroma_client.get_collection(name="ag_news")
        print(f"Connected to ChromaDB with {self.collection.count()} documents")

    def search(self, query: str, top_k: int = 5) -> Dict:
        """Search for relevant chunks and answer the question.

        Returns a dict with keys ``answer`` (str), ``chunks`` (list of
        {'text', 'category', 'score'} dicts), and ``query`` (the
        possibly-normalized query string).
        """
        # Handle empty query base case scenario: fall back to a generic
        # query instead of sending an empty embedding request.
        if not query or query.strip() == "":
            query = "news"

        # Embed the query (list form, as ChromaDB expects plain floats)
        query_embedding = self.embedder.encode(query).tolist()

        # Query the collection; clamp n_results so we never ask for more
        # documents than the collection holds.
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, self.collection.count()),
            include=["documents", "metadatas", "distances"]
        )

        # Format results
        formatted_results = []
        context_chunks = []

        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                # Map distance to a similarity score in roughly [0, 1].
                # The collection is built with cosine space (see
                # setup_data.py, hnsw:space=cosine), where distances
                # fall in [0, 2].
                distance = results['distances'][0][i] if results['distances'] else 0
                similarity_score = 1 - (distance / 2)

                doc_text = results['documents'][0][i]

                formatted_results.append({
                    'text': doc_text,
                    'category': results['metadatas'][0][i].get('label_text', 'Unknown'),
                    'score': similarity_score
                })

                context_chunks.append(doc_text)

            # Generate answer based on retrieved chunks
            answer = self._generate_answer(query, context_chunks)

            return {
                "answer": answer,
                "chunks": formatted_results,
                "query": query
            }
        else:
            # Nothing retrieved — return an empty result set with a
            # user-facing message instead of calling the LLM.
            return {
                "answer": "No relevant information found for your question.",
                "chunks": [],
                "query": query
            }

    def _generate_answer(self, query: str, chunks: List[str]) -> str:
        """Generate an answer from *chunks* via Gemini.

        Chunks are numbered [1]..[n] in the prompt so the model can
        reference them.
        """
        # Combine chunks as a single numbered context string
        context = "\n\n".join([f"[{i+1}] {chunk}" for i, chunk in enumerate(chunks)])

        # Create prompt
        prompt = f"""Based on the following information, answer the question.

Context:
{context}

Question: {query}

Answer:"""

        # Generate answer using Gemini
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return response.content

    def get_collection_stats(self) -> Dict:
        """Return {'total_documents': int, 'categories': {name: count}}.

        Category counts are estimated from a sample of at most 100
        documents, not the full collection.
        """
        count = self.collection.count()

        if count > 0:
            sample = self.collection.get(
                limit=min(100, count),
                include=["metadatas"]
            )
            categories = {}

            # Tally label_text occurrences within the sample.
            for metadata in sample['metadatas']:
                cat = metadata.get('label_text', 'Unknown')
                categories[cat] = categories.get(cat, 0) + 1

            return {
                "total_documents": count,
                "categories": categories
            }
        else:
            return {
                "total_documents": 0,
                "categories": {}
            }
|
agents/sql_agent.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
2 |
+
from langchain.agents import create_sql_agent
|
3 |
+
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
4 |
+
from langchain.sql_database import SQLDatabase
|
5 |
+
import os
|
6 |
+
|
7 |
+
class SQLAgent:
    """Natural-language Q&A over a SQLite database via a LangChain SQL agent.

    Wraps ``create_sql_agent`` with a Gemini LLM so callers can ask
    plain-English questions about the database at *db_path*.
    """

    def __init__(self, db_path: str):
        # Path kept for reference; the SQLAlchemy URI below is what the
        # toolkit actually connects through.
        self.db_path = db_path

        # Create SQLDatabase instance (read via sqlite driver)
        self.db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

        # temperature=0 for deterministic SQL generation; API key comes
        # from the environment (loaded via dotenv in app.py).
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )

        # Create SQL toolkit and agent; verbose=True logs the agent's
        # intermediate SQL reasoning to stdout.
        toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        self.agent = create_sql_agent(
            llm=self.llm,
            toolkit=toolkit,
            verbose=True
        )

    def query(self, question: str) -> str:
        """Run natural language query and return answer."""
        # NOTE(review): Agent.run is deprecated in newer LangChain
        # releases in favor of invoke() — confirm against the pinned
        # langchain version before upgrading.
        return self.agent.run(question)
|
app.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Streamlit entry point: two-mode Q&A chatbot.

Mode 1 queries the Sakila movie database through the LangChain SQL
agent; mode 2 does semantic search over AG News via the RAG agent.
"""
import streamlit as st
import pandas as pd  # NOTE(review): imported but unused in this view — confirm before removing
from agents.sql_agent import SQLAgent
from agents.rag_agent import RAGAgent
import os
from dotenv import load_dotenv

# Pull GOOGLE_API_KEY (and any other settings) from .env before the
# agents are constructed.
load_dotenv()

st.set_page_config(page_title="Q&A Chatbot", layout="wide")

st.title(" Q&A Chatbot")

# Initializing agents lazily and caching them across reruns — Streamlit
# re-executes this whole script on every interaction, and agent
# construction (DB connection, model load) is expensive.
@st.cache_resource
def init_sql_agent():
    return SQLAgent("./sakila.db")

@st.cache_resource
def init_rag_agent():
    return RAGAgent()

# Mode selector drives which branch of the page renders below.
mode = st.sidebar.radio(
    "Select Mode:",
    ["Movie Database (SQL)", "News Search (RAG)"]
)

st.markdown("---")

if mode == "Movie Database (SQL)":
    st.subheader(" Movie Database (SQL)")

    sql_question = st.text_input("Ask about movies:", placeholder="Please enter your nlp sql question here ", key="sql_input")

    if sql_question:
        # Agent is only built on first use (cached afterwards).
        sql_agent = init_sql_agent()

        with st.spinner("Querying database..."):
            answer = sql_agent.query(sql_question)

        # Display answer
        st.markdown("### Answer")
        st.write(answer)

else:
    st.subheader(" News Search (RAG)")

    # RAG input
    rag_question = st.text_input("Ask about news:", placeholder="What's happening around the world?", key="rag_input")

    if rag_question:
        rag_agent = init_rag_agent()

        with st.spinner("Searching news..."):
            # result: {'answer': str, 'chunks': [...], 'query': str}
            result = rag_agent.search(rag_question)

        st.markdown("### Answer")
        st.info(result['answer'])

        # Sources section with full chunks
        st.markdown("### Source Articles")

        for j, chunk in enumerate(result['chunks']):
            with st.container():
                col1, col2 = st.columns([3, 1])

                with col1:
                    st.markdown(f"**Article {j+1}**")
                with col2:
                    st.markdown(f"**{chunk['category']}**")

                # Full text in a text area for better readability;
                # unique key per chunk keeps Streamlit widget state
                # from colliding between articles.
                st.text_area(
                    label="",
                    value=chunk['text'],
                    height=150,
                    disabled=True,
                    key=f"chunk_{j}"
                )

                # Score if available (search() may report 0 when
                # distances are absent)
                if chunk.get('score', 0) > 0:
                    st.caption(f"Relevance Score: {chunk['score']:.1%}")

                st.markdown("---")

# Sidebar: example prompts matching the active mode
st.sidebar.markdown("---")
st.sidebar.markdown("### Example Questions")

if mode == "Movie Database (SQL)":
    st.sidebar.markdown("""
    - How many films are there?
    - Show me the top 5 longest films
    - Which actors have the most films?
    - List all film categories
    - How many customers do we have?
    """)
else:
    st.sidebar.markdown("""
    - What's happening with oil prices?
    - Tell me about technology news
    - Any sports updates?
    - Business news today
    - Science discoveries
    """)

st.sidebar.markdown("---")
st.sidebar.caption("Created by Bharath Gajula")
st.sidebar.caption("Powered by Gemini & LangChain")
|
chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ebdb4a62fc9c29c5f41ec836bd8856e9526cf002440601ca5f2fed121cb0696c
|
3 |
+
size 32120000
|
chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/header.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8c7f00b4415698ee6cb94332eff91aedc06ba8e066b1f200e78ca5df51abb57
|
3 |
+
size 100
|
chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/length.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:353eeae34f121621a657ab9bf30b59b722029f4415aff8a48f1466ef39bc7211
|
3 |
+
size 40000
|
chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/link_lists.bin
ADDED
File without changes
|
chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:23add52afbe7588391f32d3deffb581b2663d2e2ad8851aba7de25e6b3f66761
|
3 |
+
size 32120000
|
chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/header.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8c7f00b4415698ee6cb94332eff91aedc06ba8e066b1f200e78ca5df51abb57
|
3 |
+
size 100
|
chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/length.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7e2dcff542de95352682dc186432e98f0188084896773f1973276b0577d5305
|
3 |
+
size 40000
|
chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/link_lists.bin
ADDED
File without changes
|
requirements.txt
CHANGED
@@ -1,3 +1,14 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit>=1.28.0
|
2 |
+
langchain>=0.3.0
|
3 |
+
langchain-community>=0.3.0
|
4 |
+
langchain-google-genai>=1.0.0
|
5 |
+
google-generativeai>=0.3.0
|
6 |
+
chromadb>=0.4.0
|
7 |
+
sentence-transformers>=2.2.0
|
8 |
+
pandas>=2.0.0
|
9 |
+
python-dotenv>=1.0.0
|
10 |
+
pytest>=7.4.0
|
11 |
+
numpy>=1.24.0
|
12 |
+
sqlalchemy>=2.0.0
|
13 |
+
datasets>=2.14.0
|
14 |
+
requests>=2.31.0
|
setup_data.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import urllib.request
|
3 |
+
import chromadb
|
4 |
+
from chromadb.utils import embedding_functions
|
5 |
+
from datasets import load_dataset
|
6 |
+
|
7 |
+
def download_sakila_db():
    """Fetch the Sakila SQLite database into ./sakila.db if missing.

    No-op (prints a notice) when the file is already present.
    """
    db_file = "./sakila.db"

    # Skip the download entirely when a previous run left the file behind.
    if os.path.exists(db_file):
        print("✓ Sakila database already exists")
        return

    print("Downloading Sakila database...")
    source_url = "https://github.com/ivanceras/sakila/raw/master/sqlite-sakila-db/sakila.db"
    urllib.request.urlretrieve(source_url, db_file)
    print("✓ Sakila database downloaded")
|
17 |
+
|
18 |
+
def setup_agnews_chromadb():
    """Load original AG News and compute embeddings.

    Builds (or rebuilds) the persisted ChromaDB collection ``ag_news``
    in ./chroma_agnews/ from the first 500 AG News training articles,
    embedding with the same SentenceTransformer model the RAG agent
    queries with ('all-mpnet-base-v2').
    """
    print("\nLoading AG News dataset...")

    ds = load_dataset("fancyzhx/ag_news", split="train[:500]")
    print(f"✓ Loaded {len(ds)} articles")

    os.makedirs("./chroma_agnews/", exist_ok=True)
    client = chromadb.PersistentClient(path="./chroma_agnews/")

    # Drop any stale collection so reruns start clean. Narrowed from a
    # bare `except:` — that also swallowed SystemExit/KeyboardInterrupt;
    # the only expected failure here is "collection does not exist".
    try:
        client.delete_collection("ag_news")
    except Exception:
        pass

    # Create collection with embedding function; cosine space matches
    # the distance-to-similarity mapping in RAGAgent.search().
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-mpnet-base-v2"
    )

    collection = client.create_collection(
        name="ag_news",
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"}
    )

    # AG News integer labels -> human-readable category names.
    label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    # Adding to ChromaDB (embeddings computed by the collection's
    # embedding function during add()).
    print("Computing embeddings and adding to ChromaDB...")

    ids = [f"doc_{i}" for i in range(len(ds))]
    documents = [item['text'] for item in ds]
    metadatas = [{
        "label": item['label'],
        "label_text": label_names[item['label']],
        # Truncated preview of the article text for display purposes.
        "title": item['text'][:100] + "..." if len(item['text']) > 100 else item['text']
    } for item in ds]

    collection.add(
        ids=ids,
        documents=documents,
        metadatas=metadatas
    )

    print(f"✓ Added {len(ds)} articles to ChromaDB")
|
65 |
+
|
66 |
+
if __name__ == "__main__":
    print("=== Setting up databases ===\n")
    download_sakila_db()
    setup_agnews_chromadb()
    # Fixed stale filename in the hint: the Streamlit entry point is
    # app.py (see Dockerfile ENTRYPOINT and README), not chatbot.py.
    print("\n Setup complete! Run 'streamlit run app.py'")
|