Bharath Gajula commited on
Commit
42cabf2
·
1 Parent(s): d0b0e7b
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.9-slim
2
 
3
  WORKDIR /app
4
 
@@ -10,12 +10,13 @@ RUN apt-get update && apt-get install -y \
10
  && rm -rf /var/lib/apt/lists/*
11
 
12
  COPY requirements.txt ./
13
- COPY src/ ./src/
14
 
15
  RUN pip3 install -r requirements.txt
 
16
 
17
  EXPOSE 8501
18
 
19
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
 
21
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
1
+ FROM python:3.12-slim
2
 
3
  WORKDIR /app
4
 
 
10
  && rm -rf /var/lib/apt/lists/*
11
 
12
  COPY requirements.txt ./
13
+ COPY . .
14
 
15
  RUN pip3 install -r requirements.txt
16
+ RUN python setup_data.py
17
 
18
  EXPOSE 8501
19
 
20
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
21
 
22
+ ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -1,20 +1,23 @@
1
- ---
2
- title: Testing Rag
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: Streamlit template space
12
- license: apache-2.0
13
- ---
14
 
15
- # Welcome to Streamlit!
16
 
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
1
+ # Hybrid Search Chatbot
2
+
3
+ A Streamlit app for hybrid search: SQL (Chinook DB) and semantic (RAG, AG News/ChromaDB).
 
 
 
 
 
 
 
 
 
 
4
 
5
+ ## Quick Start
6
 
7
+ 1. Install dependencies:
8
+ ```bash
9
+ pip install -r requirements.txt
10
+ ```
11
+ 2. Initialize data:
12
+ ```bash
13
+ python setup_data.py
14
+ ```
15
+ 3. Run the app:
16
+ ```bash
17
+ streamlit run app.py
18
+ ```
19
+
20
+ ---
21
 
22
+ - Edit `.env` for API keys if needed.
23
+ - See `requirements.txt` for dependencies.
agents/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
"""Agent package exposing the SQL and RAG agents."""

from .sql_agent import SQLAgent
from .rag_agent import RAGAgent

# Bug fix: __all__ must name the symbols this package exports (the classes
# imported above), not the submodule file names — `from agents import *`
# previously exported nothing useful.
__all__ = ['SQLAgent', 'RAGAgent']
__version__ = "1.0.0"
__author__ = "Bharath Gajula"
agents/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (346 Bytes). View file
 
agents/__pycache__/rag_agent.cpython-312.pyc ADDED
Binary file (4.78 kB). View file
 
agents/__pycache__/sql_agent.cpython-312.pyc ADDED
Binary file (1.62 kB). View file
 
agents/rag_agent.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ from sentence_transformers import SentenceTransformer
3
+ from typing import List, Dict
4
+ import os
5
+ from langchain_google_genai import ChatGoogleGenerativeAI
6
+ from langchain.schema import HumanMessage
7
+
8
class RAGAgent:
    """Semantic (RAG) search over the AG News ChromaDB collection.

    Retrieves the most similar stored articles for a query and asks
    Gemini to synthesize an answer from them.
    """

    def __init__(self):
        # Must be the same embedding model used by setup_data.py so that
        # query vectors live in the same space as the stored documents.
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.3,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )

        persist_directory = "./chroma_agnews/"
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)

        # Raises if setup_data.py has not created the collection yet.
        self.collection = self.chroma_client.get_collection(name="ag_news")
        print(f"Connected to ChromaDB with {self.collection.count()} documents")

    def search(self, query: str, top_k: int = 5) -> Dict:
        """Search for relevant chunks and answer the question.

        Args:
            query: Natural-language question; blank input falls back to "news".
            top_k: Maximum number of chunks to retrieve.

        Returns:
            Dict with keys 'answer' (str), 'chunks' (list of dicts with
            'text', 'category', 'score') and 'query' (the effective query).
        """
        # Handle empty query base case scenario
        if not query or query.strip() == "":
            query = "news"

        # Robustness fix: an empty collection would yield n_results=0,
        # which Chroma rejects — short-circuit instead of querying.
        doc_count = self.collection.count()
        if doc_count == 0:
            return {
                "answer": "No relevant information found for your question.",
                "chunks": [],
                "query": query
            }

        # Embed the query
        query_embedding = self.embedder.encode(query).tolist()

        # Query the collection
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, doc_count),
            include=["documents", "metadatas", "distances"]
        )

        # Format results
        formatted_results = []
        context_chunks = []

        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                # Map cosine distance (0..2, per the collection's
                # "hnsw:space": "cosine" setting) onto a 0..1 similarity.
                distance = results['distances'][0][i] if results['distances'] else 0
                similarity_score = 1 - (distance / 2)

                doc_text = results['documents'][0][i]

                formatted_results.append({
                    'text': doc_text,
                    'category': results['metadatas'][0][i].get('label_text', 'Unknown'),
                    'score': similarity_score
                })

                context_chunks.append(doc_text)

            # Generate answer based on retrieved chunks
            answer = self._generate_answer(query, context_chunks)

            return {
                "answer": answer,
                "chunks": formatted_results,
                "query": query
            }
        else:
            return {
                "answer": "No relevant information found for your question.",
                "chunks": [],
                "query": query
            }

    def _generate_answer(self, query: str, chunks: List[str]) -> str:
        """Generate answer based on retrieved chunks via a single LLM call."""
        # Combine chunks as numbered context so the model can reference them.
        context = "\n\n".join([f"[{i+1}] {chunk}" for i, chunk in enumerate(chunks)])

        # Create prompt
        prompt = f"""Based on the following information, answer the question.

Context:
{context}

Question: {query}

Answer:"""

        # Generate answer using Gemini
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return response.content

    def get_collection_stats(self) -> Dict:
        """Get statistics about the collection.

        Returns total document count plus a category histogram computed
        from a sample of up to 100 documents (not the full collection).
        """
        count = self.collection.count()

        if count > 0:
            sample = self.collection.get(
                limit=min(100, count),
                include=["metadatas"]
            )
            categories = {}

            for metadata in sample['metadatas']:
                cat = metadata.get('label_text', 'Unknown')
                categories[cat] = categories.get(cat, 0) + 1

            return {
                "total_documents": count,
                "categories": categories
            }
        else:
            return {
                "total_documents": 0,
                "categories": {}
            }
agents/sql_agent.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from langchain_google_genai import ChatGoogleGenerativeAI
# Bug fix: with the pinned langchain>=0.3.0 (see requirements.txt) the old
# `langchain.agents.agent_toolkits` / `langchain.sql_database` import paths
# no longer exist — these classes live in langchain_community.
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
import os


class SQLAgent:
    """Natural-language question answering over a SQLite database.

    Wraps a LangChain SQL agent backed by Gemini; the agent inspects the
    schema, writes SQL, runs it, and summarizes the result.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path

        # Create SQLDatabase instance over the local SQLite file.
        self.db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

        # temperature=0 for deterministic SQL generation.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )

        # Create SQL toolkit and agent
        toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        self.agent = create_sql_agent(
            llm=self.llm,
            toolkit=toolkit,
            verbose=True
        )

    def query(self, question: str) -> str:
        """Run natural language query and return answer.

        Bug fix: Chain.run() is removed in langchain 0.3 — use invoke()
        and unwrap its 'output' field so the return type stays str.
        """
        result = self.agent.invoke({"input": question})
        return result["output"]
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Streamlit front-end: SQL Q&A over the Sakila movie DB and RAG search over AG News."""

import streamlit as st
import pandas as pd
from agents.sql_agent import SQLAgent
from agents.rag_agent import RAGAgent
import os
from dotenv import load_dotenv

# Load GOOGLE_API_KEY and friends from a local .env file.
load_dotenv()

st.set_page_config(page_title="Q&A Chatbot", layout="wide")

st.title(" Q&A Chatbot")


# Initializing agents — cached so Streamlit reruns reuse the same
# (expensive) agent instances instead of rebuilding them every interaction.
@st.cache_resource
def init_sql_agent():
    return SQLAgent("./sakila.db")


@st.cache_resource
def init_rag_agent():
    return RAGAgent()


mode = st.sidebar.radio(
    "Select Mode:",
    ["Movie Database (SQL)", "News Search (RAG)"]
)

st.markdown("---")

if mode == "Movie Database (SQL)":
    st.subheader(" Movie Database (SQL)")

    sql_question = st.text_input("Ask about movies:", placeholder="Please enter your nlp sql question here ", key="sql_input")

    if sql_question:
        # Agent is only constructed when a question is actually asked.
        sql_agent = init_sql_agent()

        with st.spinner("Querying database..."):
            answer = sql_agent.query(sql_question)

        # Display answer
        st.markdown("### Answer")
        st.write(answer)

else:
    st.subheader(" News Search (RAG)")

    # RAG input
    rag_question = st.text_input("Ask about news:", placeholder="What's happening around the world?", key="rag_input")

    if rag_question:
        rag_agent = init_rag_agent()

        with st.spinner("Searching news..."):
            result = rag_agent.search(rag_question)

        st.markdown("### Answer")
        st.info(result['answer'])

        # Sources section with full chunks
        st.markdown("### Source Articles")

        for j, chunk in enumerate(result['chunks']):
            with st.container():
                col1, col2 = st.columns([3, 1])

                with col1:
                    st.markdown(f"**Article {j+1}**")
                with col2:
                    st.markdown(f"**{chunk['category']}**")

                # Full text in a read-only text area for better readability.
                # Bug fix: an empty label triggers Streamlit's accessibility
                # warning — supply a real label and collapse it visually.
                st.text_area(
                    label=f"Article {j+1} text",
                    value=chunk['text'],
                    height=150,
                    disabled=True,
                    key=f"chunk_{j}",
                    label_visibility="collapsed"
                )

                # Score if available
                if chunk.get('score', 0) > 0:
                    st.caption(f"Relevance Score: {chunk['score']:.1%}")

                st.markdown("---")

# Sidebar
st.sidebar.markdown("---")
st.sidebar.markdown("### Example Questions")

if mode == "Movie Database (SQL)":
    st.sidebar.markdown("""
    - How many films are there?
    - Show me the top 5 longest films
    - Which actors have the most films?
    - List all film categories
    - How many customers do we have?
    """)
else:
    st.sidebar.markdown("""
    - What's happening with oil prices?
    - Tell me about technology news
    - Any sports updates?
    - Business news today
    - Science discoveries
    """)

st.sidebar.markdown("---")
st.sidebar.caption("Created by Bharath Gajula")
st.sidebar.caption("Powered by Gemini & LangChain")
chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebdb4a62fc9c29c5f41ec836bd8856e9526cf002440601ca5f2fed121cb0696c
3
+ size 32120000
chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8c7f00b4415698ee6cb94332eff91aedc06ba8e066b1f200e78ca5df51abb57
3
+ size 100
chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:353eeae34f121621a657ab9bf30b59b722029f4415aff8a48f1466ef39bc7211
3
+ size 40000
chroma_agnews/1ece7c94-65ca-4e00-a18d-81575d0bb13e/link_lists.bin ADDED
File without changes
chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23add52afbe7588391f32d3deffb581b2663d2e2ad8851aba7de25e6b3f66761
3
+ size 32120000
chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8c7f00b4415698ee6cb94332eff91aedc06ba8e066b1f200e78ca5df51abb57
3
+ size 100
chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7e2dcff542de95352682dc186432e98f0188084896773f1973276b0577d5305
3
+ size 40000
chroma_agnews/abe39f0b-2fae-49f8-b04a-3877bcadd8ea/link_lists.bin ADDED
File without changes
requirements.txt CHANGED
@@ -1,3 +1,14 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit>=1.28.0
2
+ langchain>=0.3.0
3
+ langchain-community>=0.3.0
4
+ langchain-google-genai>=1.0.0
5
+ google-generativeai>=0.3.0
6
+ chromadb>=0.4.0
7
+ sentence-transformers>=2.2.0
8
+ pandas>=2.0.0
9
+ python-dotenv>=1.0.0
10
+ pytest>=7.4.0
11
+ numpy>=1.24.0
12
+ sqlalchemy>=2.0.0
13
+ datasets>=2.14.0
14
+ requests>=2.31.0
setup_data.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib.request
3
+ import chromadb
4
+ from chromadb.utils import embedding_functions
5
+ from datasets import load_dataset
6
+
7
def download_sakila_db():
    """Fetch the Sakila SQLite database unless a local copy already exists."""
    target = "./sakila.db"
    if os.path.exists(target):
        print("✓ Sakila database already exists")
        return

    print("Downloading Sakila database...")
    source = "https://github.com/ivanceras/sakila/raw/master/sqlite-sakila-db/sakila.db"
    urllib.request.urlretrieve(source, target)
    print("✓ Sakila database downloaded")
+
18
def setup_agnews_chromadb():
    """Load original AG News and compute embeddings.

    Pulls the first 500 training articles, rebuilds the 'ag_news'
    ChromaDB collection from scratch, and stores documents with
    label metadata using cosine-space HNSW.
    """
    print("\nLoading AG News dataset...")

    ds = load_dataset("fancyzhx/ag_news", split="train[:500]")
    print(f"✓ Loaded {len(ds)} articles")

    os.makedirs("./chroma_agnews/", exist_ok=True)
    client = chromadb.PersistentClient(path="./chroma_agnews/")

    # Drop any stale collection so reruns start clean. Bug fix: a bare
    # `except:` also swallowed KeyboardInterrupt/SystemExit — narrow it to
    # Exception (the expected failure is "collection does not exist").
    try:
        client.delete_collection("ag_news")
    except Exception:
        pass

    # Create collection with embedding function — must match the model
    # RAGAgent uses at query time (all-mpnet-base-v2).
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-mpnet-base-v2"
    )

    collection = client.create_collection(
        name="ag_news",
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"}
    )

    # Label mapping
    label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    # Adding to ChromaDB
    print("Computing embeddings and adding to ChromaDB...")

    ids = [f"doc_{i}" for i in range(len(ds))]
    documents = [item['text'] for item in ds]
    metadatas = [{
        "label": item['label'],
        "label_text": label_names[item['label']],
        "title": item['text'][:100] + "..." if len(item['text']) > 100 else item['text']
    } for item in ds]

    collection.add(
        ids=ids,
        documents=documents,
        metadatas=metadatas
    )

    print(f"✓ Added {len(ds)} articles to ChromaDB")
65
+
66
if __name__ == "__main__":
    print("=== Setting up databases ===\n")
    download_sakila_db()
    setup_agnews_chromadb()
    # Bug fix: the entry point is app.py (see Dockerfile ENTRYPOINT and
    # README), not chatbot.py — the old message pointed at a missing file.
    print("\n Setup complete! Run 'streamlit run app.py'")