loading scripts and app stuff
Co-authored-by: Daniel <[email protected]>
Co-authored-by: Brandon <[email protected]>
Co-authored-by: Enrico <[email protected]>
Co-authored-by: Jinanshi <[email protected]>
- Dockerfile +3 -76
- RAG.py +156 -0
- load_pinecone.py +97 -0
- load_script.py +140 -0
- streamlit_app.py +148 -0
Dockerfile
CHANGED
```diff
@@ -1,111 +1,38 @@
-FROM python:3.
+FROM python:3.12.4
 
 # Create a non-root user
 RUN useradd -m -u 1000 user
 USER user
 
 # Set PATH to include user's local bin
 ENV PATH="/home/user/.local/bin:$PATH"
 
 # Set working directory
 WORKDIR /app
 
 # Copy requirements file with appropriate ownership
 COPY --chown=user ./requirements.txt requirements.txt
 
 # Install dependencies
 RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install rank_bm25
 
 # Copy application files with appropriate ownership
 COPY --chown=user . /app
 
 # Set environment variables for Streamlit
 ENV HOST=0.0.0.0
 ENV PORT=7860
 ENV STREAMLIT_SERVER_PORT=7860
 ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0
 
 # Run the app with Streamlit
-CMD ["streamlit", "run", "
+CMD ["streamlit", "run", "streamlit_app.py", "--server.port", "7860", "--server.address", "0.0.0.0"]
```
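A note on the configuration: Docker Spaces on Hugging Face expect the app to listen on port 7860 and to run without root privileges, which is why the image creates a UID-1000 user, adds ~/.local/bin to PATH so user-level pip installs are found, and binds Streamlit to 0.0.0.0:7860.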
RAG.py
ADDED
@@ -0,0 +1,156 @@
```python
import getpass
import os
import time
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
import re
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
import requests
from typing import Dict, Any, Optional, List, Tuple
import json
import logging

def retrieve(index_name: str, query: str, embeddings, k: int = 1000) -> Tuple[List[Document], List[float]]:
    load_dotenv()
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    pc = Pinecone(api_key=pinecone_api_key)

    index = pc.Index(index_name)
    vector_store = PineconeVectorStore(index=index, embedding=embeddings)
    results = vector_store.similarity_search_with_score(
        query,
        k=k,
    )
    documents = []
    scores = []
    for res, score in results:
        documents.append(res)
        scores.append(score)
    return documents, scores

def safe_get_json(url: str) -> Optional[Dict]:
    """Safely fetch and parse JSON from a URL."""
    print("Fetching JSON")
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logging.error(f"Error fetching from {url}: {str(e)}")
        return None

def extract_text_from_json(json_data: Dict) -> str:
    """Extract text content from JSON response."""
    if not json_data:
        return ""

    text_parts = []

    # Handle direct text fields
    text_fields = ["title_info_primary_tsi", "abstract_tsi", "subject_geographic_sim", "genre_specific_ssim"]
    for field in text_fields:
        if field in json_data['data']['attributes'] and json_data['data']['attributes'][field]:
            text_parts.append(str(json_data['data']['attributes'][field]))

    return " ".join(text_parts) if text_parts else "No content available"

def rerank(documents: List[Document], query: str) -> List[Document]:
    """Rerank documents using BM25, with proper error handling."""
    if not documents:
        return []

    full_docs = []
    for doc in documents:
        if not doc.metadata.get('source'):
            continue

        url = f"https://www.digitalcommonwealth.org/search/{doc.metadata['source']}"
        json_data = safe_get_json(f"{url}.json")

        if json_data:
            text_content = extract_text_from_json(json_data)
            if text_content:  # Only add documents with actual content
                full_docs.append(Document(page_content=text_content, metadata={"source": doc.metadata['source'], "field": doc.metadata['field'], "URL": url}))

    # If no valid documents were processed, return empty list
    if not full_docs:
        return []

    # Create BM25 retriever with the processed documents
    reranker = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
    reranked_docs = reranker.invoke(query)
    return reranked_docs

def parse_xml_and_check(xml_string: str) -> str:
    """Parse XML-style tags and handle validation."""
    if not xml_string:
        return "No response generated."

    pattern = r"<(\w+)>(.*?)</\1>"
    matches = re.findall(pattern, xml_string, re.DOTALL)
    parsed_response = dict(matches)

    if parsed_response.get('VALID') == 'NO':
        return "Sorry, I was unable to find any documents relevant to your query."

    return parsed_response.get('RESPONSE', "No response found in the output")

def RAG(llm: Any, query: str, index_name: str, embeddings: Any, top: int = 10, k: int = 100) -> Tuple[str, List[Document]]:
    """Main RAG function with improved error handling and validation."""
    try:
        # Retrieve initial documents
        retrieved, _ = retrieve(index_name=index_name, query=query, embeddings=embeddings, k=k)
        if not retrieved:
            return "No documents found for your query.", []

        # Rerank documents
        reranked = rerank(documents=retrieved, query=query)
        if not reranked:
            return "Unable to process the retrieved documents.", []

        # Prepare context from reranked documents
        context = "\n\n".join(doc.page_content for doc in reranked[:top] if doc.page_content)
        if not context.strip():
            return "No relevant content found in the documents.", []

        # Prepare prompt
        prompt_template = PromptTemplate.from_template(
            """Pretend you are a professional librarian. Please summarize the following context as though you had retrieved it for a patron:
            Context: {context}
            Make sure to answer in the following format.
            First, reason about the answer between <REASONING></REASONING> headers;
            based on the context, determine if there is sufficient material for answering the exact question, and
            return either <VALID>YES</VALID> or <VALID>NO</VALID>;
            then return a response between <RESPONSE></RESPONSE> headers.
            Here is an example:
            <EXAMPLE>
            <QUERY>Are pineapples a good fuel for cars?</QUERY>
            <CONTEXT>Cars use gasoline for fuel. Some cars use electricity for fuel. Tesla stock has increased by 10 percent over the last quarter.</CONTEXT>
            <REASONING>Based on the context, pineapples have not been explored as a fuel for cars. The context discusses gasoline, electricity, and Tesla stock, therefore it is not relevant to the query about pineapples as fuel.</REASONING>
            <VALID>NO</VALID>
            <RESPONSE>Pineapples are not a good fuel for cars, however with further research they might be.</RESPONSE>
            </EXAMPLE>
            Now it's your turn:
            <QUERY>
            {query}
            </QUERY>"""
        )

        # Generate response
        prompt = prompt_template.invoke({"context": context, "query": query})
        print(prompt)
        response = llm.invoke(prompt)

        # Parse and return response
        parsed = parse_xml_and_check(response.content)
        return parsed, reranked

    except Exception as e:
        logging.error(f"Error in RAG function: {str(e)}")
        return f"An error occurred while processing your query: {str(e)}", []
```
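To make the tag protocol above concrete, here is a minimal sketch of `parse_xml_and_check` on a hypothetical model output (the sample string is invented for illustration):

```python
from RAG import parse_xml_and_check

# Hypothetical LLM output following the <REASONING>/<VALID>/<RESPONSE> protocol.
sample = (
    "<REASONING>The context covers the query.</REASONING>"
    "<VALID>YES</VALID>"
    "<RESPONSE>Here is a short summary for the patron.</RESPONSE>"
)
print(parse_xml_and_check(sample))  # -> "Here is a short summary for the patron."
```

If the output instead carries `<VALID>NO</VALID>`, the function returns the apology message and ignores any `<RESPONSE>` content.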
load_pinecone.py
ADDED
@@ -0,0 +1,97 @@
```python
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pinecone import Pinecone, ServerlessSpec
from uuid import uuid4
import json
import os
from dotenv import load_dotenv
import sys

import time

load_dotenv()

BEGIN = int(sys.argv[1])
END = int(sys.argv[2])
PATH = sys.argv[3]

# Pinecone setup
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
pc = Pinecone(api_key=PINECONE_API_KEY)
INDEX_NAME = sys.argv[4]
index = pc.Index(INDEX_NAME)

print("Loading JSON...")
meta = json.load(open(PATH))

model_name = "sentence-transformers/all-MiniLM-L6-v2"
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': False}

print("Initializing Pinecone index...")

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = PineconeVectorStore(index=index, embedding=embeddings)

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    length_function=len,
    separators=["\n\n", "\n", " ", ""]
)

fields = ['abstract_tsi', 'title_info_primary_tsi', 'title_info_primary_subtitle_tsi', 'title_info_alternative_tsim']

print("Beginning Embeddings...")
start = time.time()

full_data = []
for page in meta:
    content = page['data']
    full_data += content
    if BEGIN > END:
        slice = content[BEGIN:]
    else:
        slice = content[BEGIN:END]

    num = 0
    for item in slice:
        id = item["id"]
        item_data = item["attributes"]
        print(id, time.time())
        documents = []
        for field in item_data:
            if (field in fields) or ("note" in field):
                entry = str(item_data[field])
                if len(entry) > 1000:
                    chunks = text_splitter.split_text(entry)
                    for chunk in chunks:
                        documents.append(Document(page_content=chunk, metadata={"source": id, "field": field}))
                else:
                    documents.append(Document(page_content=entry, metadata={"source": id, "field": field}))

        if num % 1000 == 0:
            print(num, f"Added vectors to vectorstore at {time.time()} on id {id}")
            print(documents)
        uuids = [str(uuid4()) for _ in range(len(documents))]
        vector_store.add_documents(documents=documents, ids=uuids)
        num += 1

end = time.time()
print(f"Embedded all documents in {end-start} seconds...")
```
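The script would be run as, e.g., `python load_pinecone.py 0 5000 out1_131.json bpl-rag` (argument order: BEGIN, END, PATH, INDEX_NAME; the values here are hypothetical). It assumes PATH points at a file in the shape produced by load_script.py below: a JSON list of saved search pages, each wrapping a `data` list of records carrying `id` and `attributes`. A hypothetical miniature of that layout:

```python
# Hypothetical miniature of the JSON that load_pinecone.py iterates over;
# the field names mirror the `fields` list above, the values are invented.
example_meta = [
    {
        "data": [
            {
                "id": "commonwealth:abc123",
                "attributes": {
                    "title_info_primary_tsi": "Map of Boston Harbor",
                    "abstract_tsi": "A nautical chart of the inner harbor.",
                },
            }
        ]
    }
]
```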
load_script.py
ADDED
@@ -0,0 +1,140 @@
```python
import json
import time
import os
import sys
import requests

def fetch_digital_commonwealth():
    start = time.time()
    BASE_URL = "https://www.digitalcommonwealth.org/search.json?search_field=all_fields&per_page=100&q="
    PAGE = sys.argv[1]
    END_PAGE = sys.argv[2]
    file_name = f"out{PAGE}_{END_PAGE}.json"
    FINAL_PAGE = 13038
    output = []
    file_path = f"./{file_name}"
    # file_path = './output.json'
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            output = json.load(file)
        if int(PAGE) < (len(output) + 1):
            PAGE = len(output) + 1

    if int(PAGE) >= int(END_PAGE):
        return None
    print(f'Reading page {PAGE} up to page {END_PAGE}')
    retries = 0
    while True:
        try:
            response = requests.get(f"{BASE_URL}&page={PAGE}")
            response.raise_for_status()
            data = response.json()

            # Append current page data to the output list
            output.append(data)

            # Save the entire output to a JSON file after each iteration
            with open(file_path, 'w') as f:
                json.dump(output, f)

            # Check if there's a next page
            # print(len(response))
            if data['meta']['pages']['next_page']:
                if data['meta']['pages']['next_page'] == int(END_PAGE):
                    print(f"Processed and saved page {PAGE}. Total pages saved: {len(output)}")
                    break
                elif data['meta']['pages']['next_page'] == FINAL_PAGE:
                    print(f"finished page {PAGE}")
                    PAGE = FINAL_PAGE
                else:
                    print(f"finished page {PAGE}")
                    PAGE = data['meta']['pages']['next_page']
            else:
                print(f"Processed and saved page {PAGE}. Total pages saved: {len(output)}")
                break

            retries = 0
            # Optional: add a small delay to avoid overwhelming the API
            # time.sleep(0.5)
        except requests.exceptions.RequestException as e:
            print(f"An error occurred: {e}")
            print(f"Processed and saved page {PAGE}. Total pages saved: {len(output)}")
            retries += 1
            if retries >= 5:
                break

    end = time.time()
    print(f"Timer: {end - start}")
    print(f"Finished processing all pages. Total pages saved: {len(output)}")

if __name__ == "__main__":
    fetch_digital_commonwealth()
```
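Pagination is driven entirely by `meta.pages.next_page` in each search.json response, so a run such as `python load_script.py 1 131` (page bounds hypothetical) resumes from an existing out1_131.json if present and stops at END_PAGE, at FINAL_PAGE, or when next_page is empty. A hypothetical miniature of the response fields the loop reads (values invented):

```python
# Hypothetical miniature of one search.json page; only these fields are read
# by the loop above.
page_response = {
    "data": [
        {"id": "commonwealth:abc123", "attributes": {"title_info_primary_tsi": "Map of Boston Harbor"}}
    ],
    "meta": {"pages": {"next_page": 2}},  # falsy on the last page
}
assert page_response["meta"]["pages"]["next_page"] == 2
```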
streamlit_app.py
ADDED
@@ -0,0 +1,148 @@
```python
import streamlit as st
import os
from typing import List, Tuple, Optional
from pinecone import Pinecone
from langchain_pinecone import PineconeVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from dotenv import load_dotenv
from RAG import RAG
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page configuration
st.set_page_config(
    page_title="Boston Public Library Chatbot",
    page_icon="🤖",
    layout="wide"
)

def initialize_models() -> Tuple[Optional[ChatOpenAI], HuggingFaceEmbeddings]:
    """Initialize the language model and embeddings."""
    try:
        load_dotenv()

        # Initialize OpenAI model
        llm = ChatOpenAI(
            model="gpt-4",  # Changed from gpt-4o-mini which appears to be a typo
            temperature=0,
            timeout=60,  # Added reasonable timeout
            max_retries=2
        )

        # Initialize embeddings
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        return llm, embeddings

    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        st.error(f"Failed to initialize models: {str(e)}")
        return None, None

def process_message(
    query: str,
    llm: ChatOpenAI,
    index_name: str,
    embeddings: HuggingFaceEmbeddings
) -> Tuple[str, List]:
    """Process the user message using the RAG system."""
    try:
        response, sources = RAG(
            query=query,
            llm=llm,
            index_name=index_name,
            embeddings=embeddings
        )
        return response, sources
    except Exception as e:
        logger.error(f"Error in process_message: {str(e)}")
        return f"Error processing message: {str(e)}", []

def display_sources(sources: List) -> None:
    """Display sources in expandable sections with proper formatting."""
    if not sources:
        st.info("No sources available for this response.")
        return

    st.subheader("Sources")
    for i, doc in enumerate(sources, 1):
        try:
            with st.expander(f"Source {i}"):
                if hasattr(doc, 'page_content'):
                    st.markdown(f"**Content:** {doc.page_content}")
                    if hasattr(doc, 'metadata'):
                        for key, value in doc.metadata.items():
                            st.markdown(f"**{key.title()}:** {value}")
                else:
                    st.markdown(f"**Content:** {str(doc)}")
        except Exception as e:
            logger.error(f"Error displaying source {i}: {str(e)}")
            st.error(f"Error displaying source {i}")

def main():
    st.title("RAG Chatbot")

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Initialize models
    llm, embeddings = initialize_models()
    if not llm or not embeddings:
        st.error("Failed to initialize the application. Please check the logs.")
        return

    # Constants
    INDEX_NAME = 'bpl-rag'

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    user_input = st.chat_input("Type your message here...")
    if user_input:
        # Display user message
        with st.chat_message("user"):
            st.markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        # Process and display assistant response
        with st.chat_message("assistant"):
            with st.spinner("Let Me Think..."):
                response, sources = process_message(
                    query=user_input,
                    llm=llm,
                    index_name=INDEX_NAME,
                    embeddings=embeddings
                )

                if isinstance(response, str):
                    st.markdown(response)
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": response
                    })

                    # Display sources
                    display_sources(sources)
                else:
                    st.error("Received an invalid response format")

    # Footer
    st.markdown("---")
    st.markdown(
        "Built with ❤️ using Streamlit + LangChain + OpenAI",
        help="An AI-powered chatbot with RAG capabilities"
    )

if __name__ == "__main__":
    main()
```
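For a quick check outside the chat UI, here is a minimal smoke-test sketch, assuming the same `.env` keys the app loads (`OPENAI_API_KEY`, `PINECONE_API_KEY`) and the existing `bpl-rag` index; the query string is invented, and importing streamlit_app outside `streamlit run` will log a bare-mode warning:

```python
# Minimal sketch: drive the RAG pipeline without the Streamlit chat UI.
# Assumes OPENAI_API_KEY and PINECONE_API_KEY are set in .env (hypothetical run).
from streamlit_app import initialize_models, process_message

llm, embeddings = initialize_models()
if llm and embeddings:
    response, sources = process_message(
        query="Photographs of the Boston Public Garden",  # hypothetical query
        llm=llm,
        index_name="bpl-rag",
        embeddings=embeddings,
    )
    print(response)
    for doc in sources[:3]:
        print(doc.metadata.get("URL"))  # rerank() stores the source URL here
```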