Update app.py
Browse files
app.py
CHANGED
@@ -2,10 +2,211 @@ import gradio as gr
|
|
2 |
import os
|
3 |
from groq import Groq
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def chat_with_groq(user_input, history):
|
6 |
-
api_key = os.getenv("GROQ_API_KEY") # Set this in your environment variables
|
7 |
-
if not api_key:
|
8 |
-
return "Error: API key not found. Set GROQ_API_KEY as an environment variable."
|
9 |
|
10 |
client = Groq(api_key=api_key)
|
11 |
|
|
|
2 |
import os
|
3 |
from groq import Groq
|
4 |
|
############ TESTING ############
import pandas as pd
from datasets import Dataset

# Schema for the toy corpus: each row is one text chunk with pointers to its
# neighbouring chunks (prechunk_id / postchunk_id) plus source metadata.
_COLUMNS = ['id', 'title', 'content', 'prechunk_id', 'postchunk_id', 'arxiv_id', 'references']

# Example chunks. Built as a plain list of dicts and turned into a DataFrame
# once, instead of the original pd.concat-per-row (quadratic and noisy).
_TEST_ROWS = [
    {
        'id': '1',
        'title': 'Best restaurants in queens',
        # BUG FIX: the original single-quoted literal broke on the apostrophe
        # in "Pan's" (SyntaxError); double quotes sidestep it.
        'content': "I personally like to go to the Pan's Chicken, they have fried chicken and amazing bubble tea.",
        'prechunk_id': '',
        'postchunk_id': '2',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:9012.3456', 'arXiv:7890.1234'],
    },
    {
        'id': '2',
        'title': 'Best restaurants in queens',
        'content': 'if you like asian food, flushing is second to none.',
        'prechunk_id': '1',
        'postchunk_id': '3',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:6543.2109', 'arXiv:3210.9876'],
    },
    {
        'id': '3',
        'title': 'Best restaurants in queens',
        'content': 'you have to try the baked ziti from EC',
        'prechunk_id': '2',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '4',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'theres a hidden gem called The Lounge, you can play poker and blackjack and darts',
        'prechunk_id': '',
        'postchunk_id': '5',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '5',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'if its a nice day, basketball at Non-non-Fiction Park is always fun',
        'prechunk_id': '',
        'postchunk_id': '6',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '7',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'nothing beats the subway, even with delays its the fastest option. you can transfer between the bus and subway with one swipe',
        'prechunk_id': '',
        'postchunk_id': '8',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '8',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'if youre going to the bar, its honestly worth ubering there. MTA while drunk isnt something id recommend.',
        'prechunk_id': '7',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
]

# Single DataFrame construction preserves the declared column order.
test_dataset_df = pd.DataFrame(_TEST_ROWS, columns=_COLUMNS)

# Convert the DataFrame to a Hugging Face Dataset object
test_dataset = Dataset.from_pandas(test_dataset_df)

data = test_dataset
def _to_record(row):
    """Collapse a raw chunk row into an id plus a metadata payload."""
    return {
        "id": row["id"],
        "metadata": {
            "title": row["title"],
            "content": row["content"],
        },
    }

# Reshape every row, then drop the raw columns folded into `metadata`.
data = data.map(_to_record)
data = data.remove_columns(
    ["title", "content", "prechunk_id", "postchunk_id", "arxiv_id", "references"]
)
from semantic_router.encoders import HuggingFaceEncoder

# Sentence encoder used for both document and query embeddings.
encoder = HuggingFaceEncoder(name="dwzhu/e5-base-4k")

# Embed a throwaway string once only to discover the embedding width; `dims`
# is needed below when creating the Pinecone index.
embeds = encoder(["this is a test"])
dims = len(embeds[0])

############ TESTING ############
import os
import getpass  # NOTE(review): imported but never used in this section
from pinecone import Pinecone
from google.colab import userdata  # NOTE(review): Colab-only; app.py will fail outside Colab — confirm deployment target

# initialize connection to pinecone (get API key at app.pinecone.io)
api_key = userdata.get('PINECONE_API_KEY')

# configure client
pc = Pinecone(api_key=api_key)

from pinecone import ServerlessSpec

# Serverless index placement: AWS, us-east-1.
spec = ServerlessSpec(
    cloud="aws", region="us-east-1"
)

import time

index_name = "groq-llama-3-rag"
# Names of every index already present in this Pinecone project.
existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
]

# check if index already exists (it shouldn't if this is first time)
if index_name not in existing_indexes:
    # if does not exist, create index
    pc.create_index(
        index_name,
        dimension=dims,  # embedding width probed from the encoder above
        metric='cosine',
        spec=spec
    )
    # wait for index to be initialized
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats (return value is discarded here — presumably a leftover
# from the notebook this was copied from, where it auto-displays)
index.describe_index_stats()
from tqdm.auto import tqdm

batch_size = 2  # how many embeddings we create and insert at once

# Walk the dataset in fixed-size windows, embed each window's text, and push
# (id, vector, metadata) triples into the Pinecone index.
total = len(data)
for start in tqdm(range(0, total, batch_size)):
    stop = min(total, start + batch_size)
    # slicing a HF Dataset yields a dict of column-name -> list for the window
    window = data[start:stop]
    # "title: content" is the text that gets embedded for each chunk
    texts = [f'{meta["title"]}: {meta["content"]}' for meta in window["metadata"]]
    vectors = encoder(texts)
    assert len(vectors) == stop - start
    payload = list(zip(window["id"], vectors, window["metadata"]))
    index.upsert(vectors=payload)
def get_docs(query: str, top_k: int) -> list[str]:
    """Return the `content` text of the top_k chunks most similar to `query`.

    Relies on the module-level `encoder` and Pinecone `index`.
    """
    # embed the query with the same encoder used for the documents
    query_vector = encoder([query])
    # nearest-neighbour search against the Pinecone index
    response = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    # pull just the chunk text out of each match
    return [match["metadata"]["content"] for match in response["matches"]]
from groq import Groq
from google.colab import userdata
# NOTE(review): the secret is looked up under '1GROQ_API_KEY' (leading '1') —
# looks like a typo for 'GROQ_API_KEY'; confirm against the stored Colab secrets.
groq_client = Groq(api_key=userdata.get('1GROQ_API_KEY'))
def generate(query: str, docs: list[str]):
    """Answer `query` with llama3-70b via Groq, grounded in the given context.

    Args:
        query: The user's question.
        docs: Retrieved context chunks; joined with "---" separators into the
            system prompt.

    Returns:
        The model's reply text (str), taken from the first completion choice.
    """
    system_message = (
        "Pretend you are a friend that lives in New York City. "
        "Please answer while prioritizing the "
        "context provided below.\n\n"
        "CONTEXT:\n"
        # BUG FIX: the original omitted this '+'. Python's implicit literal
        # concatenation then produced ("...CONTEXT:\n" + "\n---\n").join(docs),
        # i.e. the whole prompt became the join *separator* — with a single doc
        # the instructions were dropped entirely.
        + "\n---\n".join(docs)
    )
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    # generate response
    chat_response = groq_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=messages,
    )
    return chat_response.choices[0].message.content
# Smoke-test the RAG pipeline end to end: retrieve context, then generate.
query = "Favorite drink in Queens?"
docs = get_docs(query, top_k=5)
# NOTE(review): `out` is never printed or used below — presumably notebook
# leftovers; confirm whether this module-level call should run in app.py.
out = generate(query=query, docs=docs)
206 |
+
|
207 |
+
#### to implement
|
208 |
+
|
209 |
def chat_with_groq(user_input, history):
|
|
|
|
|
|
|
210 |
|
211 |
client = Groq(api_key=api_key)
|
212 |
|