Spaces:
Sleeping
Sleeping
ShantanuD
committed on
Commit
·
50bdeb1
1
Parent(s):
bca3e9c
Retriever Made
Browse files- app.py +79 -52
- requirements.txt +8 -1
app.py
CHANGED
@@ -1,64 +1,91 @@
|
|
1 |
-
import
|
2 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
"""
|
7 |
-
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
max_tokens,
|
15 |
-
temperature,
|
16 |
-
top_p,
|
17 |
-
):
|
18 |
-
messages = [{"role": "system", "content": system_message}]
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
25 |
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
stream=True,
|
34 |
-
temperature=temperature,
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
41 |
|
|
|
|
|
|
|
42 |
|
43 |
-
""
|
44 |
-
|
45 |
-
"""
|
46 |
-
demo = gr.ChatInterface(
|
47 |
-
respond,
|
48 |
-
additional_inputs=[
|
49 |
-
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
50 |
-
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
51 |
-
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
52 |
-
gr.Slider(
|
53 |
-
minimum=0.1,
|
54 |
-
maximum=1.0,
|
55 |
-
value=0.95,
|
56 |
-
step=0.05,
|
57 |
-
label="Top-p (nucleus sampling)",
|
58 |
-
),
|
59 |
-
],
|
60 |
-
)
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
if __name__ == "__main__":
|
64 |
-
demo.launch()
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from langchain_community.document_loaders import HuggingFaceDatasetLoader
|
3 |
+
from langchain_text_splitters import CharacterTextSplitter
|
4 |
+
from langchain.vectorstores import Chroma
|
5 |
+
from langchain_aws import BedrockEmbeddings
|
6 |
+
from langchain.chat_models import ChatBedrock
|
7 |
+
from langchain.schema import HumanMessage
|
8 |
+
import os
|
9 |
|
10 |
+
# Optional: For Cohere reranking
|
11 |
+
import cohere
|
|
|
|
|
12 |
|
13 |
+
# --- Load and Prepare Dataset (run once and cache)
|
14 |
+
@st.cache_resource
def load_and_index_data(max_docs=100):
    """Build the Chroma vector store over a sample of the Wikipedia dataset.

    Decorated with ``st.cache_resource`` so the index is built once and
    reused across Streamlit script reruns instead of on every interaction.

    Args:
        max_docs: Maximum number of dataset rows to load before chunking.

    Returns:
        The ``Chroma`` vector store populated with the embedded chunks.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import islice

    # HuggingFaceDatasetLoader takes the dataset id as ``path``. It accepts
    # neither ``dataset_name`` nor ``load_max_docs`` (both raise TypeError),
    # and ``name`` selects a dataset *configuration*, not a split, so
    # "train" must not be passed there.
    loader = HuggingFaceDatasetLoader(
        path="Cohere/wikipedia-22-12-simple-embeddings",
        page_content_column="text",
    )
    # The loader has no row limit of its own; cap it on the lazy iterator.
    documents = list(islice(loader.lazy_load(), max_docs))

    # The loader copies every non-content column into ``metadata``. Chroma
    # only accepts scalar metadata values, so drop anything else.
    # NOTE(review): presumably this removes the dataset's precomputed
    # embedding column -- confirm against the dataset schema.
    for doc in documents:
        doc.metadata = {
            key: value
            for key, value in doc.metadata.items()
            if isinstance(value, (str, int, float, bool))
        }

    splitter = CharacterTextSplitter(separator="\n", chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(documents)

    embedding = BedrockEmbeddings(
        model_id="amazon.titan-embed-text-v1",
        region_name="us-east-1",
    )

    vectordb = Chroma.from_documents(
        documents=chunks,
        embedding=embedding,
        persist_directory="./chromadb",
    )
    # Kept from the original; NOTE(review): recent Chroma versions persist
    # automatically, so this call may be redundant -- confirm before removing.
    vectordb.persist()
    return vectordb
|
39 |
|
40 |
+
# --- Re-rank using Claude 3 Sonnet via Bedrock
|
41 |
+
def rerank_with_claude(query, docs):
    """Ask a Claude model on Bedrock to rank ``docs`` by relevance to ``query``.

    Args:
        query: The user's search query.
        docs: Retrieved documents; each must expose ``page_content``.

    Returns:
        The model's free-text answer (``response.content``) containing the
        ranking and a brief explanation. The ranking is not parsed.
    """
    claude = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        region_name="us-east-1",
    )

    # Number the documents from 1 so the model can refer to them by label.
    context = "\n\n".join(
        f"Document {position}:\n{doc.page_content}"
        for position, doc in enumerate(docs, start=1)
    )

    # Assembled with explicit newlines so the prompt text does not depend on
    # how the surrounding source happens to be indented.
    prompt = (
        "You are a helpful assistant tasked with re-ranking search results "
        "based on their relevance to a user query.\n"
        "\n"
        f"Query: {query}\n"
        "\n"
        "Documents:\n"
        f"{context}\n"
        "\n"
        "Please rank the documents in order of relevance to the query and explain briefly."
    )

    # ``invoke`` replaces the deprecated direct call on the chat model.
    response = claude.invoke([HumanMessage(content=prompt)])
    return response.content
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
# --- Re-rank using Cohere
|
59 |
+
def rerank_with_cohere(query, docs, model="rerank-english-v3.0", top_n=5):
    """Re-rank ``docs`` against ``query`` with the Cohere Rerank endpoint.

    Args:
        query: The user's search query.
        docs: Retrieved documents; each must expose ``page_content``.
        model: Cohere rerank model name. The Rerank API lists ``model`` as
            required, so it is always sent explicitly.
        top_n: Maximum number of ranked results to request.

    Returns:
        The Cohere rerank response. Each entry in its ``results`` carries an
        ``index`` into the submitted ``docs`` list and a ``relevance_score``.
    """
    # The API key is read from Streamlit's secrets store.
    co = cohere.Client(st.secrets["COHERE_API_KEY"])
    documents = [doc.page_content for doc in docs]
    return co.rerank(model=model, query=query, documents=documents, top_n=top_n)
|
64 |
|
65 |
+
# --- Streamlit UI
|
66 |
+
# Page setup and the two user inputs: the question and the re-ranking mode.
st.set_page_config(page_title="Re-ranking Demo", layout="wide")
st.title("🔎 Wikipedia Search with Re-ranking")

query = st.text_input("Enter your question:")
rerank_method = st.selectbox(
    "Choose re-ranking method:",
    ["None (Baseline)", "Claude 3.5 (Bedrock)", "Cohere Rerank"],
)

if query:
    # Baseline: top-5 nearest chunks from the cached vector store.
    vectordb = load_and_index_data()
    retriever = vectordb.as_retriever(search_kwargs={"k": 5})
    # ``invoke`` replaces the deprecated ``get_relevant_documents``.
    baseline_results = retriever.invoke(query)

    st.subheader("🔍 Baseline Results")
    for i, doc in enumerate(baseline_results):
        st.markdown(f"**Doc {i+1}:** {doc.page_content[:300]}...")

    if rerank_method == "Claude 3.5 (Bedrock)":
        st.subheader("✨ Re-ranked Results (Claude 3.5)")
        ranked_text = rerank_with_claude(query, baseline_results)
        st.text(ranked_text)

    elif rerank_method == "Cohere Rerank":
        st.subheader("✨ Re-ranked Results (Cohere)")
        reranked = rerank_with_cohere(query, baseline_results)
        for i, result in enumerate(reranked.results):
            # ``result.document`` is not a plain string, so slicing it fails.
            # ``result.index`` points back into the list that was submitted,
            # so take the text from the baseline documents instead.
            text = baseline_results[result.index].page_content
            st.markdown(f"**Doc {i+1}** (score: {result.relevance_score:.2f}):\n\n{text[:300]}...")
|
91 |
|
|
|
|
requirements.txt
CHANGED
@@ -1 +1,8 @@
|
|
1 |
-
huggingface_hub==0.25.2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub==0.25.2
|
2 |
+
streamlit
|
3 |
+
langchain
|
4 |
+
langchain-community
|
5 |
+
langchain-aws
|
6 |
+
chromadb
|
7 |
+
cohere
|
8 |
+
boto3
|