# RAG chatbot demo: Gradio UI + Pinecone vector retrieval + Groq Llama 3 generation.
import gradio as gr
import os
from groq import Groq
############ TESTING ############
import pandas as pd
from datasets import Dataset
# Define the dataset schema and populate it with example chunks.
# Each record is one "chunk" of a discussion thread: `prechunk_id` /
# `postchunk_id` link neighbouring chunks of the same thread, while
# `arxiv_id` / `references` are placeholder citation fields kept only
# for schema compatibility with the original arXiv-style pipeline.
_TEST_RECORDS = [
    {
        'id': '1',
        'title': 'Best restaurants in queens',
        'content': 'I personally like to go to the J-Pan Chicken, they have fried chicken and amazing bubble tea.',
        'prechunk_id': '',
        'postchunk_id': '2',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:9012.3456', 'arXiv:7890.1234'],
    },
    {
        'id': '2',
        'title': 'Best restaurants in queens',
        'content': 'if you like asian food, flushing is second to none.',
        'prechunk_id': '1',
        'postchunk_id': '3',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:6543.2109', 'arXiv:3210.9876'],
    },
    {
        'id': '3',
        'title': 'Best restaurants in queens',
        'content': 'you have to try the ziti from ECC',
        'prechunk_id': '2',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '6',
        'title': 'Best restaurants in queens',
        'content': 'theres a good halal cart on Wub Street, they give extra sticky creamy white sauce',
        'prechunk_id': '',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '4',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'theres a hidden gem called The Lounge, you can play poker and blackjack and darts',
        'prechunk_id': '',
        'postchunk_id': '5',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '5',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'if its a nice day, basketball at Non-non-Fiction Park is always fun',
        'prechunk_id': '',
        'postchunk_id': '6',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '7',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'nothing beats the subway, even with delays its the fastest option. you can transfer between the bus and subway with one swipe',
        'prechunk_id': '',
        'postchunk_id': '8',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '8',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'if youre going to the bar, its honestly worth ubering there. MTA while drunk isnt something id recommend.',
        'prechunk_id': '7',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
]
# Build the frame in one shot: the previous version seeded an empty frame and
# grew it with repeated pd.concat calls, which is O(n^2) and harder to read.
test_dataset_df = pd.DataFrame(
    _TEST_RECORDS,
    columns=['id', 'title', 'content', 'prechunk_id', 'postchunk_id', 'arxiv_id', 'references'],
)
# Wrap the pandas frame as a Hugging Face Dataset, then reshape every record
# into {"id", "metadata": {"title", "content"}} — the layout used when
# upserting vectors to Pinecone below.
test_dataset = Dataset.from_pandas(test_dataset_df)
data = test_dataset.map(
    lambda record: {
        "id": record["id"],
        "metadata": {
            "title": record["title"],
            "content": record["content"],
        },
    }
)
# Drop the flat columns that were folded into `metadata` (or are unused downstream).
data = data.remove_columns(
    ["title", "content", "prechunk_id", "postchunk_id", "arxiv_id", "references"]
)
from semantic_router.encoders import HuggingFaceEncoder
# Sentence encoder (e5-base-4k) used for both corpus chunks and user queries.
encoder = HuggingFaceEncoder(name="dwzhu/e5-base-4k")
# Probe with a dummy sentence to discover the embedding dimensionality,
# which is needed when creating the Pinecone index below.
embeds = encoder(["this is a test"])
dims = len(embeds[0])
############ TESTING ############
import os
import getpass
import time
from pinecone import Pinecone, ServerlessSpec

# Initialize the Pinecone client (get an API key at app.pinecone.io).
api_key = os.getenv("PINECONE_API_KEY")
pc = Pinecone(api_key=api_key)
spec = ServerlessSpec(cloud="aws", region="us-east-1")

index_name = "groq-llama-3-rag"
existing_indexes = [info["name"] for info in pc.list_indexes()]
# Create the index only on first run; size it to the encoder's output width.
if index_name not in existing_indexes:
    pc.create_index(
        index_name,
        dimension=dims,
        metric='cosine',
        spec=spec,
    )
    # Block until the new index reports ready.
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# Connect to the (now-existing) index and peek at its stats.
index = pc.Index(index_name)
time.sleep(1)
index.describe_index_stats()
from tqdm.auto import tqdm

# Embed and upsert the corpus into Pinecone in small batches.
batch_size = 2  # embeddings created and inserted per round-trip
for start in tqdm(range(0, len(data), batch_size)):
    stop = min(len(data), start + batch_size)
    batch = data[start:stop]
    # One "title: content" passage per record feeds the encoder.
    passages = [f'{x["title"]}: {x["content"]}' for x in batch["metadata"]]
    vectors = encoder(passages)
    assert len(vectors) == stop - start
    # Pinecone expects (id, vector, metadata) triples.
    index.upsert(vectors=list(zip(batch["id"], vectors, batch["metadata"])))
def get_docs(query: str, top_k: int) -> list[str]:
    """Embed `query` and return the `top_k` most similar chunk contents.

    Args:
        query: Natural-language search string.
        top_k: Number of nearest matches to retrieve from Pinecone.

    Returns:
        The `content` metadata field of each match, best match first.
    """
    # encoder() takes a batch and returns a list of embeddings; Pinecone's
    # query endpoint wants a single flat vector, so unwrap the first (only)
    # one.  (BUGFIX: previously the nested list was passed straight through.)
    xq = encoder([query])[0]
    res = index.query(vector=xq, top_k=top_k, include_metadata=True)
    # Pull just the raw chunk text out of each match.
    return [match["metadata"]["content"] for match in res["matches"]]
from groq import Groq
# Groq API client for chat completions; reads GROQ_API_KEY from the environment.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
def generate(query: str, history):
# Create system message
if not history:
system_message = (
"You are a friendly and knowledgeable New Yorker who loves sharing recommendations about the city. "
"You have lived in NYC for years and know both the famous tourist spots and hidden local gems. "
"Your goal is to give recommendations tailored to what the user is asking for, whether they want iconic attractions "
"or lesser-known spots loved by locals.\n\n"
"Use the provided context to enhance your responses with real local insights, but only include details that are relevant "
"to the user’s question. If the context provides useful recommendations that match what the user is asking for, use them. "
"If the context is unrelated or does not fully answer the question, rely on your general NYC knowledge instead.\n\n"
"Be specific when recommending places—mention neighborhoods, the atmosphere, and why someone might like a spot. "
"Keep your tone warm, conversational, and engaging, like a close friend who genuinely enjoys sharing their city.\n\n"
"CONTEXT:\n"
"\n---\n".join(get_docs(query, top_k=5))
)
messages = [
{"role": "system", "content": system_message},
]
else:
# Establish history
messages = []
for user_msg, bot_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": bot_msg})
messages.append({"role": "assistant", "content": bot_msg})
system_message = (
"Here is additional context based on the newest query.\n\n"
"CONTEXT:\n"
"\n---\n".join(get_docs(query, top_k=5))
)
messages.append({"role": "system", "content": system_message})
# Add query
messages.append({"role": "user", "content": query})
# generate response
chat_response = groq_client.chat.completions.create(
model="llama3-70b-8192",
messages=messages
)
return chat_response.choices[0].message.content
# Custom CSS for iPhone-style chat
custom_css = """
.gradio-container {
background: transparent !important;
}
.chat-message {
display: flex;
align-items: center;
margin-bottom: 10px;
}
.chat-message.user {
justify-content: flex-end;
}
.chat-message.assistant {
justify-content: flex-start;
}
.chat-bubble {
padding: 10px 15px;
border-radius: 20px;
max-width: 70%;
font-size: 16px;
display: inline-block;
}
.chat-bubble.user {
background-color: #007aff;
color: white;
border-bottom-right-radius: 5px;
}
.chat-bubble.assistant {
background-color: #f0f0f0;
color: black;
border-bottom-left-radius: 5px;
}
.profile-pic {
width: 40px;
height: 40px;
border-radius: 50%;
margin: 0 10px;
}
"""
# Gradio Interface
demo = gr.ChatInterface(generate, css=custom_css)
demo.launch()