File size: 8,970 Bytes
48cbcfa
480b1f1
 
48cbcfa
c45ac2f
 
 
 
 
 
 
 
 
 
 
5c3f278
c45ac2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776c160
 
 
 
 
 
 
 
 
 
d8a0fee
 
c45ac2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e2a637
c45ac2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e2a637
c45ac2f
1606bed
 
d728bc5
47ac53b
9e45e7f
0459a8f
 
 
 
 
 
 
 
5c3f278
9e45e7f
 
 
d728bc5
 
 
9e45e7f
 
 
 
 
 
 
 
d728bc5
 
 
9e45e7f
 
1606bed
 
 
 
c45ac2f
 
 
 
 
 
 
48cbcfa
997c8f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22a66e0
1606bed
48cbcfa
22a66e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import gradio as gr
import os
from groq import Groq

############ TESTING ############
import pandas as pd
from datasets import Dataset

# Hand-written forum-style chunks used as a stand-in corpus for the RAG demo.
# Schema mirrors an arXiv-chunk layout: id/title/content plus prechunk/postchunk
# links, a (dummy) arxiv_id, and (dummy) reference lists.
# Building the rows as a list and constructing the DataFrame once avoids the
# O(n^2) repeated pd.concat calls (and pandas' deprecated concat-with-empty
# DataFrame behavior) of the previous version.
_TEST_ROWS = [
    {
        'id': '1',
        'title': 'Best restaurants in queens',
        'content': 'I personally like to go to the J-Pan Chicken, they have fried chicken and amazing bubble tea.',
        'prechunk_id': '',
        'postchunk_id': '2',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:9012.3456', 'arXiv:7890.1234']
    },
    {
        'id': '2',
        'title': 'Best restaurants in queens',
        'content': 'if you like asian food, flushing is second to none.',
        'prechunk_id': '1',
        'postchunk_id': '3',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:6543.2109', 'arXiv:3210.9876']
    },
    {
        'id': '3',
        'title': 'Best restaurants in queens',
        'content': 'you have to try the ziti from ECC',
        'prechunk_id': '2',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456']
    },
    {
        'id': '6',
        'title': 'Best restaurants in queens',
        'content': 'theres a good halal cart on Wub Street, they give extra sticky creamy white sauce',
        'prechunk_id': '',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456']
    },
    {
        'id': '4',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'theres a hidden gem called The Lounge, you can play poker and blackjack and darts',
        'prechunk_id': '',
        'postchunk_id': '5',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456']
    },
    {
        'id': '5',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'if its a nice day, basketball at Non-non-Fiction Park is always fun',
        'prechunk_id': '',
        'postchunk_id': '6',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456']
    },
    {
        'id': '7',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'nothing beats the subway, even with delays its the fastest option. you can transfer between the bus and subway with one swipe',
        'prechunk_id': '',
        'postchunk_id': '8',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456']
    },
    {
        'id': '8',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'if youre going to the bar, its honestly worth ubering there. MTA while drunk isnt something id recommend.',
        'prechunk_id': '7',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456']
    },
]

# Explicit column list pins the original column order.
test_dataset_df = pd.DataFrame(
    _TEST_ROWS,
    columns=['id', 'title', 'content', 'prechunk_id', 'postchunk_id', 'arxiv_id', 'references'],
)

# Convert the DataFrame to a Hugging Face Dataset object
test_dataset = Dataset.from_pandas(test_dataset_df)

data = test_dataset

def _nest_metadata(row):
    """Keep the chunk id and fold title/content into a nested `metadata` dict."""
    return {
        "id": row["id"],
        "metadata": {"title": row["title"], "content": row["content"]},
    }

data = data.map(_nest_metadata)

# Drop the flat columns that are no longer needed once metadata is nested.
data = data.remove_columns(
    ["title", "content", "prechunk_id", "postchunk_id", "arxiv_id", "references"]
)

from semantic_router.encoders import HuggingFaceEncoder

# Embedding model wrapper; "dwzhu/e5-base-4k" is a Hugging Face model name
# (downloaded on first use).
encoder = HuggingFaceEncoder(name="dwzhu/e5-base-4k")

# Embed one throwaway sentence to discover the embedding dimensionality;
# `dims` is needed when creating the Pinecone index below.
embeds = encoder(["this is a test"])
dims = len(embeds[0])

############ TESTING ############

import os
import getpass
from pinecone import Pinecone

# initialize connection to pinecone (get API key at app.pinecone.io)
api_key = os.getenv("PINECONE_API_KEY")

# configure client
pc = Pinecone(api_key=api_key)

from pinecone import ServerlessSpec

# Serverless deployment target for the index created below.
spec = ServerlessSpec(
    cloud="aws", region="us-east-1"
)

import time

index_name = "groq-llama-3-rag"
# Names of all indexes already present in this Pinecone project.
existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
]

# check if index already exists (it shouldn't if this is first time)
if index_name not in existing_indexes:
    # if does not exist, create index
    pc.create_index(
        index_name,
        dimension=dims,
        metric='cosine',
        spec=spec
    )
    # wait for index to be initialized
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats
index.describe_index_stats()

from tqdm.auto import tqdm

batch_size = 2  # how many embeddings we create and insert at once

# Embed and upsert the corpus in small batches, with a progress bar.
total = len(data)
for start in tqdm(range(0, total, batch_size)):
    end = min(total, start + batch_size)
    # HF Dataset slicing yields a dict of parallel column lists.
    batch = data[start:end]
    # Embed "title: content" for each chunk in the batch.
    chunks = [f'{meta["title"]}: {meta["content"]}' for meta in batch["metadata"]]
    embeds = encoder(chunks)
    assert len(embeds) == (end - start)
    # Pinecone upsert takes (id, vector, metadata) triples.
    index.upsert(vectors=list(zip(batch["id"], embeds, batch["metadata"])))

def get_docs(query: str, top_k: int) -> list[str]:
    """Embed *query*, search the Pinecone index, and return matching chunk texts.

    Args:
        query: Natural-language search string.
        top_k: Number of nearest chunks to retrieve.

    Returns:
        The ``content`` field of each matched chunk, best match first.
    """
    # encoder() takes a batch and returns a batch of embeddings; unwrap the
    # single result — Pinecone's query expects one flat vector, not [[...]].
    xq = encoder([query])[0]
    # search pinecone index
    res = index.query(vector=xq, top_k=top_k, include_metadata=True)
    # get doc text
    docs = [match["metadata"]["content"] for match in res["matches"]]
    return docs

from groq import Groq
# Groq-hosted LLM client; expects GROQ_API_KEY in the environment.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

def generate(query: str, history):
    """Chat handler for Gradio: answer *query* using RAG context.

    Args:
        query: The user's newest message.
        history: Gradio's list of (user_msg, bot_msg) pairs from prior turns
            (empty/falsy on the first turn).

    Returns:
        The assistant's reply text from the Groq chat completion.
    """
    # Retrieve supporting chunks once and join them with a visible separator.
    # NOTE: the context must be attached with an explicit `+` below — the old
    # code relied on adjacent string literals, so Python concatenated the whole
    # prompt into the .join() separator and the intended header/separator
    # structure was never produced.
    context = "\n---\n".join(get_docs(query, top_k=5))

    if not history:
        # First turn: install the full persona prompt plus retrieved context.
        system_message = (
            "You are a friendly and knowledgeable New Yorker who loves sharing recommendations about the city. "
            "You have lived in NYC for years and know both the famous tourist spots and hidden local gems. "
            "Your goal is to give recommendations tailored to what the user is asking for, whether they want iconic attractions "
            "or lesser-known spots loved by locals.\n\n"
            "Use the provided context to enhance your responses with real local insights, but only include details that are relevant "
            "to the user’s question. If the context provides useful recommendations that match what the user is asking for, use them. "
            "If the context is unrelated or does not fully answer the question, rely on your general NYC knowledge instead.\n\n"
            "Be specific when recommending places—mention neighborhoods, the atmosphere, and why someone might like a spot. "
            "Keep your tone warm, conversational, and engaging, like a close friend who genuinely enjoys sharing their city.\n\n"
            "CONTEXT:\n" + context
        )
        messages = [
            {"role": "system", "content": system_message},
        ]
    else:
        # Replay prior turns so the model keeps conversational state.
        # (The old code appended each assistant message twice.)
        messages = []
        for user_msg, bot_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
        # Fresh retrieval for the newest query.
        system_message = (
            "Here is additional context based on the newest query.\n\n"
            "CONTEXT:\n" + context
        )
        messages.append({"role": "system", "content": system_message})

    # Add query
    messages.append({"role": "user", "content": query})

    # generate response
    chat_response = groq_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=messages
    )
    return chat_response.choices[0].message.content


# Custom CSS for iPhone-style chat
# (blue right-aligned user bubbles, grey left-aligned assistant bubbles,
# round profile pictures; injected into the Gradio app below).
custom_css = """
.gradio-container {
    background: transparent !important;
}
.chat-message {
    display: flex;
    align-items: center;
    margin-bottom: 10px;
}
.chat-message.user {
    justify-content: flex-end;
}
.chat-message.assistant {
    justify-content: flex-start;
}
.chat-bubble {
    padding: 10px 15px;
    border-radius: 20px;
    max-width: 70%;
    font-size: 16px;
    display: inline-block;
}
.chat-bubble.user {
    background-color: #007aff;
    color: white;
    border-bottom-right-radius: 5px;
}
.chat-bubble.assistant {
    background-color: #f0f0f0;
    color: black;
    border-bottom-left-radius: 5px;
}
.profile-pic {
    width: 40px;
    height: 40px;
    border-radius: 50%;
    margin: 0 10px;
}
"""

# Gradio Interface
# ChatInterface calls generate(query, history) for every user message.
demo = gr.ChatInterface(generate, css=custom_css)

demo.launch()