mt3842ml commited on
Commit
c45ac2f
·
verified ·
1 Parent(s): 997c8f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -3
app.py CHANGED
@@ -2,10 +2,211 @@ import gradio as gr
2
  import os
3
  from groq import Groq
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def chat_with_groq(user_input, history):
6
- api_key = os.getenv("GROQ_API_KEY") # Set this in your environment variables
7
- if not api_key:
8
- return "Error: API key not found. Set GROQ_API_KEY as an environment variable."
9
 
10
  client = Groq(api_key=api_key)
11
 
 
2
  import os
3
  from groq import Groq
4
 
5
############ TESTING ############
import pandas as pd
from datasets import Dataset

# Hand-written test fixture: a few Reddit-style "chunks" shoehorned into the
# arxiv-chunk schema (id / title / content / prechunk_id / postchunk_id /
# arxiv_id / references).  Built as a plain list of dicts and converted to a
# DataFrame once -- the original grew the frame with pd.concat inside a
# per-row pattern (quadratic and noisy) and its first `content` literal
# contained an unescaped apostrophe ("Pan's"), which is a SyntaxError in a
# single-quoted string; double quotes fix that.
_test_rows = [
    {
        'id': '1',
        'title': 'Best restaurants in queens',
        'content': "I personally like to go to the Pan's Chicken, they have fried chicken and amazing bubble tea.",
        'prechunk_id': '',
        'postchunk_id': '2',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:9012.3456', 'arXiv:7890.1234'],
    },
    {
        'id': '2',
        'title': 'Best restaurants in queens',
        'content': 'if you like asian food, flushing is second to none.',
        'prechunk_id': '1',
        'postchunk_id': '3',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:6543.2109', 'arXiv:3210.9876'],
    },
    {
        'id': '3',
        'title': 'Best restaurants in queens',
        'content': 'you have to try the baked ziti from EC',
        'prechunk_id': '2',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '4',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'theres a hidden gem called The Lounge, you can play poker and blackjack and darts',
        'prechunk_id': '',
        'postchunk_id': '5',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '5',
        'title': 'Spending a saturday in queens; what to do?',
        'content': 'if its a nice day, basketball at Non-non-Fiction Park is always fun',
        'prechunk_id': '',
        'postchunk_id': '6',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '7',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'nothing beats the subway, even with delays its the fastest option. you can transfer between the bus and subway with one swipe',
        'prechunk_id': '',
        'postchunk_id': '8',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
    {
        'id': '8',
        'title': 'visiting queens for the weekend, how to get around?',
        'content': 'if youre going to the bar, its honestly worth ubering there. MTA while drunk isnt something id recommend.',
        'prechunk_id': '7',
        'postchunk_id': '',
        'arxiv_id': '2401.04088',
        'references': ['arXiv:1234.5678', 'arXiv:9012.3456'],
    },
]

# Same column order as the original empty-frame declaration.
test_dataset_df = pd.DataFrame(
    _test_rows,
    columns=['id', 'title', 'content', 'prechunk_id', 'postchunk_id', 'arxiv_id', 'references'],
)

# Convert the DataFrame to a Hugging Face Dataset object.
test_dataset = Dataset.from_pandas(test_dataset_df)

data = test_dataset

# Reshape each row to the {id, metadata:{title, content}} layout that the
# Pinecone upsert loop below expects.
data = data.map(lambda x: {
    "id": x["id"],
    "metadata": {
        "title": x["title"],
        "content": x["content"],
    }
})
# Drop the columns that were folded into `metadata` or are unused.
data = data.remove_columns([
    "title", "content", "prechunk_id",
    "postchunk_id", "arxiv_id", "references"
])
100
+
101
from semantic_router.encoders import HuggingFaceEncoder

# e5-base-4k sentence-embedding model used for both document and query
# encoding below.
encoder = HuggingFaceEncoder(name="dwzhu/e5-base-4k")

# Probe with a throwaway sentence to discover the embedding dimensionality;
# `dims` is needed when creating the Pinecone index.
embeds = encoder(["this is a test"])
dims = len(embeds[0])

############ TESTING ############
109
+
110
import os
import getpass
import time
from pinecone import Pinecone, ServerlessSpec

# Pinecone API key.  The original code only read Colab secrets
# (`google.colab.userdata`), which raises ImportError anywhere outside Colab
# -- and this file is app.py, i.e. meant to run as a deployed app.  Keep the
# Colab path when available and fall back to an environment variable.
try:
    from google.colab import userdata  # only importable inside Colab
    api_key = userdata.get('PINECONE_API_KEY')
except ImportError:
    api_key = os.getenv('PINECONE_API_KEY')

# configure client
pc = Pinecone(api_key=api_key)

# Serverless index placement.
spec = ServerlessSpec(cloud="aws", region="us-east-1")

index_name = "groq-llama-3-rag"
existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]

# Create the index only on first run; `dims` comes from the encoder probe above.
if index_name not in existing_indexes:
    pc.create_index(
        index_name,
        dimension=dims,
        metric='cosine',
        spec=spec
    )
    # wait for the index to be initialized before using it
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats
index.describe_index_stats()
152
+
153
from tqdm.auto import tqdm

# How many embeddings we create and push to Pinecone per round trip.
batch_size = 2

for start in tqdm(range(0, len(data), batch_size)):
    stop = min(len(data), start + batch_size)
    # Slicing a Dataset yields a dict of column -> list for the batch.
    batch = data[start:stop]
    # Embed "title: content" for every chunk in the batch.
    chunks = [f'{meta["title"]}: {meta["content"]}' for meta in batch["metadata"]]
    embeds = encoder(chunks)
    assert len(embeds) == stop - start
    # (id, vector, metadata) triples, then upsert to Pinecone.
    to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
    index.upsert(vectors=to_upsert)
169
+
170
def get_docs(query: str, top_k: int) -> list[str]:
    """Embed *query* and return the `content` of the top_k nearest chunks.

    Relies on the module-level `encoder` and Pinecone `index` defined above.
    """
    # encoder() takes a batch and returns a list of embeddings, so index [0]
    # to get the single query vector -- Pinecone's `vector` argument expects
    # one vector (a flat list of floats), not a batch of one.
    xq = encoder([query])
    res = index.query(vector=xq[0], top_k=top_k, include_metadata=True)
    # pull the raw chunk text out of each match's metadata
    docs = [match["metadata"]['content'] for match in res["matches"]]
    return docs
178
+
179
from groq import Groq

# Groq client.  As with Pinecone above, `google.colab.userdata` only exists
# inside Colab; fall back to the GROQ_API_KEY environment variable that the
# rest of this app already uses.
# NOTE(review): the Colab secret name '1GROQ_API_KEY' (leading '1') looks
# like a typo for 'GROQ_API_KEY' -- confirm against the stored secret.
try:
    from google.colab import userdata
    groq_client = Groq(api_key=userdata.get('1GROQ_API_KEY'))
except ImportError:
    groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
182
+
183
def build_system_message(docs: list[str]) -> str:
    """Compose the system prompt with the retrieved docs appended as context."""
    # BUG FIX: the original wrote the prompt as adjacent string literals
    # immediately followed by "\n---\n".join(docs).  Implicit literal
    # concatenation binds before the method call, so `.join` applied to the
    # ENTIRE prompt text -- the "system message" became the docs joined by
    # the whole prompt.  The explicit '+' restores prompt + joined context.
    return (
        "Pretend you are a friend that lives in New York City. "
        "Please answer while prioritizing the "
        "context provided below.\n\n"
        "CONTEXT:\n" + "\n---\n".join(docs)
    )

def generate(query: str, docs: list[str]):
    """Answer *query* with the Groq Llama-3 model, grounded on *docs*.

    Returns the assistant message text from the first completion choice.
    """
    messages = [
        {"role": "system", "content": build_system_message(docs)},
        {"role": "user", "content": query}
    ]
    # generate response
    chat_response = groq_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=messages
    )
    return chat_response.choices[0].message.content
201
+
202
# Smoke-test the retrieval + generation pipeline end to end.
# NOTE(review): this runs at import time and hits Pinecone and Groq -- likely
# meant to be folded into the chat handler below (see "to implement").
query = "Favorite drink in Queens?"
docs = get_docs(query, top_k=5)
out = generate(query=query, docs=docs)


#### to implement
208
+
209
  def chat_with_groq(user_input, history):
 
 
 
210
 
211
  client = Groq(api_key=api_key)
212