|
import gradio as gr |
|
import os |
|
from groq import Groq |
|
|
|
|
|
import pandas as pd |
|
from datasets import Dataset |
|
|
|
|
|
# Column layout of the toy corpus. Only `id`, `title` and `content` are read
# later in this file; the other columns are removed again before indexing.
_TEST_COLUMNS = [
    'id', 'title', 'content', 'prechunk_id', 'postchunk_id', 'arxiv_id',
    'references',
]

# One dict per chunk. `prechunk_id` / `postchunk_id` link neighbouring chunks
# of the same thread; '' marks the first / last chunk of a thread.
_TEST_RECORDS = [
    {
        'id': '1',
        'title': 'Best restaurants in queens',
        'content': 'I personally like to go to the Pan Chicken, they have fried chicken and amazing bubble tea.',
        'prechunk_id': '',
        'postchunk_id': '2',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:9012.3456', 'arXiv:7890.1234'],
    },
    {
        'id': '2',
        'title': 'Best restaurants in queens',
        'content': 'if you like asian food, flushing is second to none.',
        'prechunk_id': '1',
        'postchunk_id': '3',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:6543.2109', 'arXiv:3210.9876'],
    },
    {
        'id': '3',
        'title': 'Best restaurants in queens',
        'content': 'you have to try the baked ziti from EC',
        'prechunk_id': '2',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '4',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'theres a hidden gem called The Lounge, you can play poker and blackjack and darts',
        'prechunk_id': '',
        'postchunk_id': '5',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '5',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'if its a nice day, basketball at Non-non-Fiction Park is always fun',
        # Chunk '4' names this chunk as its successor and no chunk '6'
        # exists, so the links are '4' / '' (previously '' / '6').
        'prechunk_id': '4',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '7',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'nothing beats the subway, even with delays its the fastest option. you can transfer between the bus and subway with one swipe',
        'prechunk_id': '',
        'postchunk_id': '8',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '8',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'if youre going to the bar, its honestly worth ubering there. MTA while drunk isnt something id recommend.',
        'prechunk_id': '7',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
]

# Build the frame in one step instead of concatenating row by row onto an
# empty frame: same rows, same column order, same 0..n-1 index.
test_dataset_df = pd.DataFrame(_TEST_RECORDS, columns=_TEST_COLUMNS)
|
|
|
|
|
def _to_index_record(row):
    """Return the ``id`` plus a nested ``metadata`` dict (title, content) for one row."""
    metadata = {
        "title": row["title"],
        "content": row["content"],
    }
    return {"id": row["id"], "metadata": metadata}


# Wrap the pandas frame in a Hugging Face Dataset.
test_dataset = Dataset.from_pandas(test_dataset_df)

# Add the nested `metadata` column to every row ...
data = test_dataset.map(_to_index_record)

# ... then drop the flat source columns so only `id` and `metadata` remain.
data = data.remove_columns(
    ["title", "content", "prechunk_id", "postchunk_id", "arxiv_id", "references"]
)
|
|
|
from semantic_router.encoders import HuggingFaceEncoder |
|
|
|
# Embedding model shared by the indexing loop and by query-time retrieval.
encoder = HuggingFaceEncoder(name="dwzhu/e5-base-4k")

# Embed one throwaway sentence so the vector width can be read off the result;
# `dims` is later used as the dimension of the Pinecone index.
probe_texts = ["this is a test"]
embeds = encoder(probe_texts)
dims = len(embeds[0])
|
|
|
|
|
|
|
import os |
|
import getpass |
|
from pinecone import Pinecone |
|
|
|
|
|
# The Pinecone key is taken from the environment. A missing variable yields
# None, which is handed to the client unchanged.
api_key = os.environ.get("PINECONE_API_KEY")

# Client object used for all index management calls in this file.
pc = Pinecone(api_key=api_key)
|
|
|
from pinecone import ServerlessSpec |
|
|
|
# Serverless deployment target passed to `pc.create_index` when the index
# has to be created.
spec = ServerlessSpec(cloud="aws", region="us-east-1")
|
|
|
import time |
|
|
|
index_name = "groq-llama-3-rag"

# Names of the indexes that already exist for this client.
existing_indexes = [info["name"] for info in pc.list_indexes()]

if index_name not in existing_indexes:
    # Index is missing: create it with the encoder's vector width.
    pc.create_index(index_name, dimension=dims, metric='cosine', spec=spec)

    # Poll once per second until the index reports itself as ready.
    while True:
        if pc.describe_index(index_name).status['ready']:
            break
        time.sleep(1)

# Open a handle to the (new or pre-existing) index.
index = pc.Index(index_name)
time.sleep(1)

# Stats are requested but the result is not used.
index.describe_index_stats()
|
|
|
from tqdm.auto import tqdm |
|
|
|
batch_size = 2

# Embed and upload the corpus `batch_size` rows at a time.
for start in tqdm(range(0, len(data), batch_size)):
    stop = min(start + batch_size, len(data))

    # Slicing the dataset yields a dict of column lists for these rows.
    batch = data[start:stop]
    batch_meta = batch["metadata"]

    # The text that gets embedded is "<title>: <content>".
    chunks = [f'{meta["title"]}: {meta["content"]}' for meta in batch_meta]
    embeds = encoder(chunks)
    assert len(embeds) == (stop - start)

    # One (id, vector, metadata) tuple per row.
    to_upsert = [
        (row_id, vector, meta)
        for row_id, vector, meta in zip(batch["id"], embeds, batch_meta)
    ]

    index.upsert(vectors=to_upsert)
|
|
|
def get_docs(query: str, top_k: int) -> list[str]:
    """Return the ``content`` text of the chunks the index matches to *query*.

    The query is embedded with the module-level ``encoder`` and looked up in
    the module-level Pinecone ``index``.

    Args:
        query: Text to search for.
        top_k: Number of matches requested from the index.

    Returns:
        The ``content`` metadata field of each match, in the order the index
        returned them.
    """
    # The encoder takes a batch of texts and returns one embedding per text.
    # The index query expects a single flat vector, so unwrap the only entry
    # instead of passing the whole one-element batch.
    xq = encoder([query])[0]

    res = index.query(vector=xq, top_k=top_k, include_metadata=True)

    return [match["metadata"]["content"] for match in res["matches"]]
|
|
|
from groq import Groq |
|
# Groq chat client; the key is read from the GROQ_API_KEY environment
# variable and passed through as-is (None when the variable is unset).
groq_api_key = os.environ.get("GROQ_API_KEY")
groq_client = Groq(api_key=groq_api_key)
|
|
|
# Persona prompt sent as the first system message of every request.
_PERSONA_PROMPT = (
    "You are a friendly and knowledgeable New Yorker who loves giving recommendations about the city. "
    "You have lived in NYC for years and know both the famous tourist spots and the hidden local gems. "
    "You specialize in giving recommendations that match people's interests, whether they are looking for iconic attractions or off-the-beaten-path experiences. "
    "Use the provided context to enhance your responses with insights from real locals, but if the context doesn't fully cover the question, rely on your own knowledge of NYC to give the best possible answer. "
    "Keep your tone warm, conversational, and engaging, like a friend who genuinely enjoys sharing their city. "
    "Be specific when recommending places—mention neighborhoods, atmosphere, and why someone might like a spot. \n"
)


def _history_to_messages(history):
    """Convert chat history into a list of role/content message dicts."""
    messages = []
    for turn in history:
        if isinstance(turn, dict):
            # Role/content dict history: forward only the two fields the
            # chat-completion request uses.
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        # Pair-style history: one (user message, assistant reply) per turn.
        user_msg, bot_msg = turn
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    return messages


def generate(query: str, history):
    """Answer *query* with the Groq chat model, using retrieved context.

    The five best-matching chunks for *query* are fetched with ``get_docs``
    and supplied to the model in a system message. On the first turn the
    context is appended to the persona prompt. On later turns the persona
    prompt is sent first, followed by the replayed history and a separate
    system message carrying the fresh context.

    Args:
        query: The newest user message.
        history: Earlier turns of the conversation; empty or ``None`` on the
            first turn. Each entry is either a ``(user, assistant)`` pair or
            a dict with ``"role"`` and ``"content"`` keys.

    Returns:
        The text content of the first choice in the model's response.
    """
    # Retrieval depends only on the newest query, so do it once up front.
    # The explicit "+" below matters: adjacent string literals are merged
    # before a method call binds, so writing `"CONTEXT:\n" "\n---\n".join(...)`
    # would use the whole prompt as the join separator.
    context = "\n---\n".join(get_docs(query, top_k=5))

    if not history:
        messages = [
            {"role": "system", "content": _PERSONA_PROMPT + "CONTEXT:\n" + context},
        ]
    else:
        # Keep the persona on follow-up turns, then replay the conversation
        # (each past message exactly once) before adding the new context.
        messages = [{"role": "system", "content": _PERSONA_PROMPT}]
        messages.extend(_history_to_messages(history))
        messages.append({
            "role": "system",
            "content": (
                "Here is additional context based on the newest query.\n\n"
                "CONTEXT:\n" + context
            ),
        })

    messages.append({"role": "user", "content": query})

    # NOTE(review): confirm this model id is still served by Groq.
    chat_response = groq_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=messages,
    )
    return chat_response.choices[0].message.content
|
|
|
|
|
|
|
# Stylesheet handed to the Gradio interface: transparent page background plus
# chat-bubble rules for user and assistant messages.
custom_css = """
.gradio-container {
    background: transparent !important;
}
.chat-message {
    display: flex;
    align-items: center;
    margin-bottom: 10px;
}
.chat-message.user {
    justify-content: flex-end;
}
.chat-message.assistant {
    justify-content: flex-start;
}
.chat-bubble {
    padding: 10px 15px;
    border-radius: 20px;
    max-width: 70%;
    font-size: 16px;
    display: inline-block;
}
.chat-bubble.user {
    background-color: #007aff;
    color: white;
    border-bottom-right-radius: 5px;
}
.chat-bubble.assistant {
    background-color: #f0f0f0;
    color: black;
    border-bottom-left-radius: 5px;
}
.profile-pic {
    width: 40px;
    height: 40px;
    border-radius: 50%;
    margin: 0 10px;
}
"""
|
|
|
|
|
# Chat UI whose handler is `generate`, styled with `custom_css`.
demo = gr.ChatInterface(fn=generate, css=custom_css)

# Start the Gradio app.
demo.launch()
|
|