NYC-Buddy / app.py
mt3842ml's picture
Update app.py
d8a0fee verified
raw
history blame
8.97 kB
import gradio as gr
import os
from groq import Groq
############ TESTING ############
import pandas as pd
from datasets import Dataset
# Hand-written sample posts used as a tiny test corpus for the retrieval
# pipeline. Each record follows the same schema: an id, a thread title, the
# post content, links to neighbouring chunks, and placeholder arXiv fields.
_test_records = [
    {
        'id': '1',
        'title': 'Best restaurants in queens',
        'content': 'I personally like to go to the J-Pan Chicken, they have fried chicken and amazing bubble tea.',
        'prechunk_id': '',
        'postchunk_id': '2',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:9012.3456', 'arXiv:7890.1234'],
    },
    {
        'id': '2',
        'title': 'Best restaurants in queens',
        'content': 'if you like asian food, flushing is second to none.',
        'prechunk_id': '1',
        'postchunk_id': '3',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:6543.2109', 'arXiv:3210.9876'],
    },
    {
        'id': '3',
        'title': 'Best restaurants in queens',
        'content': 'you have to try the ziti from ECC',
        'prechunk_id': '2',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '6',
        'title': 'Best restaurants in queens',
        'content': 'theres a good halal cart on Wub Street, they give extra sticky creamy white sauce',
        'prechunk_id': '',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '4',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'theres a hidden gem called The Lounge, you can play poker and blackjack and darts',
        'prechunk_id': '',
        'postchunk_id': '5',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '5',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'if its a nice day, basketball at Non-non-Fiction Park is always fun',
        'prechunk_id': '',
        'postchunk_id': '6',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '7',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'nothing beats the subway, even with delays its the fastest option. you can transfer between the bus and subway with one swipe',
        'prechunk_id': '',
        'postchunk_id': '8',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '8',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'if youre going to the bar, its honestly worth ubering there. MTA while drunk isnt something id recommend.',
        'prechunk_id': '7',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
]

# Build the frame in a single call. Growing it row by row with pd.concat
# re-copies the whole frame on every append, and concatenating onto an empty
# frame relies on behaviour that recent pandas versions warn about.
# `columns=` pins the schema and column order explicitly.
test_dataset_df = pd.DataFrame(
    _test_records,
    columns=['id', 'title', 'content', 'prechunk_id', 'postchunk_id', 'arxiv_id', 'references'],
)
# Wrap the pandas frame in a Hugging Face Dataset.
test_dataset = Dataset.from_pandas(test_dataset_df)


def _nest_metadata(row):
    """Keep the row id and group title/content under a ``metadata`` dict."""
    metadata = {
        "title": row["title"],
        "content": row["content"],
    }
    return {"id": row["id"], "metadata": metadata}


data = test_dataset.map(_nest_metadata)

# Drop the flat source columns; only "id" and "metadata" remain.
_unneeded_columns = [
    "title", "content", "prechunk_id",
    "postchunk_id", "arxiv_id", "references",
]
data = data.remove_columns(_unneeded_columns)
from semantic_router.encoders import HuggingFaceEncoder
# Load the text encoder, then embed one throwaway sentence so the embedding
# length can be measured (it is used as the index dimension further down).
encoder = HuggingFaceEncoder(name="dwzhu/e5-base-4k")
_probe_texts = ["this is a test"]
embeds = encoder(_probe_texts)
dims = len(embeds[0])
############ TESTING ############
import os
import getpass
from pinecone import Pinecone
# Connect to Pinecone. The API key comes from the PINECONE_API_KEY
# environment variable (keys are issued at app.pinecone.io).
api_key = os.getenv("PINECONE_API_KEY")
pc = Pinecone(api_key=api_key)

from pinecone import ServerlessSpec

# Serverless deployment target used if the index has to be created.
spec = ServerlessSpec(cloud="aws", region="us-east-1")

import time

index_name = "groq-llama-3-rag"

# Collect the names of the indexes that already exist in this project.
existing_indexes = []
for index_info in pc.list_indexes():
    existing_indexes.append(index_info["name"])

# Create the index only when it is missing (expected on the first run).
if not (index_name in existing_indexes):
    pc.create_index(index_name, dimension=dims, metric='cosine', spec=spec)
    # Poll once per second until the new index reports itself ready.
    while True:
        if pc.describe_index(index_name).status['ready']:
            break
        time.sleep(1)

# Open a handle to the index and pause briefly before first use.
index = pc.Index(index_name)
time.sleep(1)

# Request index stats (the result is not stored).
index.describe_index_stats()
from tqdm.auto import tqdm
batch_size = 2  # number of records embedded and upserted per request

_total_records = len(data)
for start in tqdm(range(0, _total_records, batch_size)):
    # Clamp the slice end so the final batch may be shorter.
    stop = min(_total_records, start + batch_size)
    batch = data[start:stop]

    # Embed "title: content" for every record in the batch.
    chunks = [f'{meta["title"]}: {meta["content"]}' for meta in batch["metadata"]]
    embeds = encoder(chunks)
    assert len(embeds) == (stop - start)

    # Pair each id with its vector and metadata, then write to Pinecone.
    to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
    index.upsert(vectors=to_upsert)
def get_docs(query: str, top_k: int) -> list[str]:
    """Return the ``content`` text of the ``top_k`` index matches for ``query``.

    Args:
        query: Free-text question to embed and search with.
        top_k: Number of matches to request from the index.

    Returns:
        The ``metadata["content"]`` value of each returned match, in the
        order the index returned them.
    """
    # NOTE(review): the encoder is called with a one-item batch and its whole
    # output is passed as `vector` -- confirm the Pinecone client accepts this
    # shape rather than requiring the single inner embedding.
    query_vector = encoder([query])
    response = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    return [match["metadata"]['content'] for match in response["matches"]]
from groq import Groq
# Groq API client; the key is read from the GROQ_API_KEY environment variable.
groq_client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)
# Persona / behaviour instructions sent as the leading system message.
_PERSONA_PROMPT = (
    "You are a friendly and knowledgeable New Yorker who loves sharing recommendations about the city. "
    "You have lived in NYC for years and know both the famous tourist spots and hidden local gems. "
    "Your goal is to give recommendations tailored to what the user is asking for, whether they want iconic attractions "
    "or lesser-known spots loved by locals.\n\n"
    "Use the provided context to enhance your responses with real local insights, but only include details that are relevant "
    "to the user’s question. If the context provides useful recommendations that match what the user is asking for, use them. "
    "If the context is unrelated or does not fully answer the question, rely on your general NYC knowledge instead.\n\n"
    "Be specific when recommending places—mention neighborhoods, the atmosphere, and why someone might like a spot. "
    "Keep your tone warm, conversational, and engaging, like a close friend who genuinely enjoys sharing their city.\n\n"
)

# Lead-in for the retrieved context that is appended on follow-up turns.
_FOLLOW_UP_PREFIX = "Here is additional context based on the newest query.\n\n"


def _format_context(docs):
    """Render retrieved documents as a ``CONTEXT:`` section separated by ``---`` lines."""
    # The join is applied to the separator alone. Writing
    # `"CONTEXT:\n" "\n---\n".join(docs)` would first fuse the adjacent
    # literals (and everything before them) into one big separator string.
    return "CONTEXT:\n" + "\n---\n".join(docs)


def _history_to_messages(history):
    """Convert Gradio chat history into chat-completion message dicts."""
    messages = []
    for turn in history:
        if isinstance(turn, dict):
            # "messages"-style history: already role/content dicts.
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        # Pair-style history: one (user, assistant) pair per turn.
        user_msg, bot_msg = turn
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    return messages


def generate(query: str, history):
    """Answer ``query`` with the Groq chat model, grounded in retrieved context.

    The top 5 documents for ``query`` are fetched with ``get_docs`` and sent
    to the model as a ``CONTEXT:`` section. The persona system prompt is sent
    on every turn, because the chat history handed to this callback contains
    only user/assistant messages.

    Args:
        query: The newest user message.
        history: Previous turns of the conversation, either as
            ``(user, assistant)`` pairs or as ``{"role", "content"}`` dicts.
            Empty or ``None`` on the first turn.

    Returns:
        The text content of the first completion choice.
    """
    context = _format_context(get_docs(query, top_k=5))

    if not history:
        # First turn: persona and context share one system message.
        messages = [{"role": "system", "content": _PERSONA_PROMPT + context}]
    else:
        # Follow-up turn: persona first, then the conversation so far (each
        # message exactly once), then fresh context for the newest query.
        messages = [{"role": "system", "content": _PERSONA_PROMPT.strip()}]
        messages.extend(_history_to_messages(history))
        messages.append({"role": "system", "content": _FOLLOW_UP_PREFIX + context})

    # Add query
    messages.append({"role": "user", "content": query})

    # NOTE(review): confirm this model id is still served by Groq.
    chat_response = groq_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=messages,
    )
    return chat_response.choices[0].message.content
# Custom CSS for an iPhone-style chat: transparent page background, user
# messages aligned right in blue bubbles, assistant messages aligned left in
# grey bubbles, and round 40px profile pictures.
# NOTE(review): the .chat-message / .chat-bubble / .profile-pic selectors are
# assumed to match elements rendered by the chat UI -- confirm against the
# Gradio version in use.
custom_css = """
.gradio-container {
background: transparent !important;
}
.chat-message {
display: flex;
align-items: center;
margin-bottom: 10px;
}
.chat-message.user {
justify-content: flex-end;
}
.chat-message.assistant {
justify-content: flex-start;
}
.chat-bubble {
padding: 10px 15px;
border-radius: 20px;
max-width: 70%;
font-size: 16px;
display: inline-block;
}
.chat-bubble.user {
background-color: #007aff;
color: white;
border-bottom-right-radius: 5px;
}
.chat-bubble.assistant {
background-color: #f0f0f0;
color: black;
border-bottom-left-radius: 5px;
}
.profile-pic {
width: 40px;
height: 40px;
border-radius: 50%;
margin: 0 10px;
}
"""
# Build the Gradio chat UI around generate() with the custom stylesheet,
# then start serving it.
demo = gr.ChatInterface(
    fn=generate,
    css=custom_css,
)
demo.launch()