# NYC-Buddy — app.py
# Hugging Face Space by mt3842ml (commit 47ac53b, 7.32 kB)
import gradio as gr
import os
from groq import Groq
############ TESTING ############
import pandas as pd
from datasets import Dataset
# Mock dataset of "local advice" forum chunks used to exercise the RAG
# pipeline end-to-end. Each row mimics one chunk of a thread:
# prechunk_id/postchunk_id link neighbouring chunks ('' = no neighbour).
# The arxiv_id/references fields are leftovers of the schema this demo
# was adapted from and are dropped before indexing.
_SCHEMA = ['id', 'title', 'content', 'prechunk_id', 'postchunk_id',
           'arxiv_id', 'references']
_TEST_ROWS = [
    {
        'id': '1',
        'title': 'Best restaurants in queens',
        'content': 'I personally like to go to the Pan Chicken, they have fried chicken and amazing bubble tea.',
        'prechunk_id': '',
        'postchunk_id': '2',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:9012.3456', 'arXiv:7890.1234'],
    },
    {
        'id': '2',
        'title': 'Best restaurants in queens',
        'content': 'if you like asian food, flushing is second to none.',
        'prechunk_id': '1',
        'postchunk_id': '3',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:6543.2109', 'arXiv:3210.9876'],
    },
    {
        'id': '3',
        'title': 'Best restaurants in queens',
        'content': 'you have to try the baked ziti from EC',
        'prechunk_id': '2',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '4',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'theres a hidden gem called The Lounge, you can play poker and blackjack and darts',
        'prechunk_id': '',
        'postchunk_id': '5',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '5',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'if its a nice day, basketball at Non-non-Fiction Park is always fun',
        'prechunk_id': '',
        'postchunk_id': '6',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '7',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'nothing beats the subway, even with delays its the fastest option. you can transfer between the bus and subway with one swipe',
        'prechunk_id': '',
        'postchunk_id': '8',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '8',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'if youre going to the bar, its honestly worth ubering there. MTA while drunk isnt something id recommend.',
        'prechunk_id': '7',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
]
# Build the frame in one shot: the original repeatedly pd.concat-ed
# single-row frames onto an initially-empty DataFrame, which is O(n^2)
# and raises FutureWarnings (concatenation with empty/all-NA entries)
# in recent pandas.
test_dataset_df = pd.DataFrame(_TEST_ROWS, columns=_SCHEMA)
# Convert the DataFrame into a Hugging Face Dataset object.
test_dataset = Dataset.from_pandas(test_dataset_df)
data = test_dataset


def _to_record(row):
    # Collapse each row to the id + metadata layout used when upserting
    # vectors to Pinecone below.
    return {
        "id": row["id"],
        "metadata": {"title": row["title"], "content": row["content"]},
    }


data = data.map(_to_record)
# Drop the columns that were folded into "metadata" (or never needed).
data = data.remove_columns([
    "title", "content", "prechunk_id",
    "postchunk_id", "arxiv_id", "references"
])
from semantic_router.encoders import HuggingFaceEncoder

# e5-base-4k embeds both the documents (at index time) and the queries.
encoder = HuggingFaceEncoder(name="dwzhu/e5-base-4k")
# Probe the encoder once so we know the embedding dimensionality to use
# when creating the Pinecone index.
embeds = encoder(["this is a test"])
dims = len(embeds[0])
############ TESTING ############
import os
import getpass
from pinecone import Pinecone

# Initialize the Pinecone connection (API key from app.pinecone.io,
# supplied via the Space's environment secrets).
api_key = os.getenv("PINECONE_API_KEY")
pc = Pinecone(api_key=api_key)

from pinecone import ServerlessSpec

# Serverless index hosted on AWS us-east-1.
spec = ServerlessSpec(cloud="aws", region="us-east-1")

import time

index_name = "groq-llama-3-rag"
# Names of all indexes that already exist in this Pinecone project.
existing_indexes = [info["name"] for info in pc.list_indexes()]
# Create the index on the first run; later runs just reconnect to it.
if index_name not in existing_indexes:
    pc.create_index(
        index_name,
        dimension=dims,  # must match the encoder's embedding size
        metric='cosine',
        spec=spec
    )
    # Block until the freshly created index reports ready.
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# Connect to the index.
index = pc.Index(index_name)
time.sleep(1)
# View index stats (vector count etc.) — useful while debugging.
index.describe_index_stats()
from tqdm.auto import tqdm

batch_size = 2  # embeddings created and upserted per request

for start in tqdm(range(0, len(data), batch_size)):
    # Slicing a Dataset yields a dict of column-name -> list for the batch.
    stop = min(len(data), start + batch_size)
    batch = data[start:stop]
    # Embed "title: content" so the thread title contributes to retrieval.
    texts = [f'{meta["title"]}: {meta["content"]}' for meta in batch["metadata"]]
    embeds = encoder(texts)
    assert len(embeds) == stop - start
    # Pinecone takes (id, vector, metadata) triples.
    to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
    index.upsert(vectors=to_upsert)
def get_docs(query: str, top_k: int) -> list[str]:
    """Return the `content` text of the `top_k` chunks most relevant to `query`.

    The query is embedded with the same encoder used at indexing time and
    matched against the Pinecone index; only each match's `content`
    metadata field is returned, ready to be joined into a prompt.
    """
    # encoder() takes a batch and returns a batch of vectors; Pinecone's
    # `vector` argument expects one flat embedding, so take element 0
    # (the original passed the whole one-element batch).
    xq = encoder([query])[0]
    # search pinecone index
    res = index.query(vector=xq, top_k=top_k, include_metadata=True)
    # get doc text
    docs = [x["metadata"]['content'] for x in res["matches"]]
    return docs
from groq import Groq

groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))


def generate(query: str, history):
    """Answer `query` as a friendly New Yorker, grounded in retrieved chunks.

    Parameters
    ----------
    query : str
        The user's latest chat message.
    history : list
        Prior (user, assistant) turn pairs as supplied by gr.ChatInterface;
        empty on the first turn.

    Returns
    -------
    str
        The assistant's reply text.

    Fixes over the original:
    - the original returned early ("History doesn't exist") whenever
      history was empty, so the bot never answered the first message;
    - it appended history turns to `messages` before `messages` was
      defined, raising NameError on every later turn;
    - implicit string-literal concatenation next to `"\\n---\\n".join(...)`
      made the WHOLE system prompt the join separator; the prompt is now
      built with an explicit `+`.
    """
    # Retrieve context for this query and fold it into the system prompt.
    context = "\n---\n".join(get_docs(query, top_k=5))
    system_message = (
        "Pretend you are a friend that lives in New York City. "
        "Please answer while prioritizing fun and unique answers using the "
        "context provided below.\n\n"
        "CONTEXT:\n" + context
    )
    messages = [{"role": "system", "content": system_message}]
    # Replay prior turns so the model sees the conversation so far.
    for user_msg, bot_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    # Add the current query last.
    messages.append({"role": "user", "content": query})
    # Generate the response.
    chat_response = groq_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=messages
    )
    return chat_response.choices[0].message.content
# Custom CSS for iPhone-style chat
custom_css = """
.gradio-container {
background: transparent !important;
}
.chat-message {
display: flex;
align-items: center;
margin-bottom: 10px;
}
.chat-message.user {
justify-content: flex-end;
}
.chat-message.assistant {
justify-content: flex-start;
}
.chat-bubble {
padding: 10px 15px;
border-radius: 20px;
max-width: 70%;
font-size: 16px;
display: inline-block;
}
.chat-bubble.user {
background-color: #007aff;
color: white;
border-bottom-right-radius: 5px;
}
.chat-bubble.assistant {
background-color: #f0f0f0;
color: black;
border-bottom-left-radius: 5px;
}
.profile-pic {
width: 40px;
height: 40px;
border-radius: 50%;
margin: 0 10px;
}
"""
# Gradio Interface: ChatInterface wires generate(message, history) to a
# chat UI, styled with the custom CSS, and launch() starts the server.
demo = gr.ChatInterface(generate, css=custom_css)
demo.launch()