File size: 4,952 Bytes
dd73f68
262247c
480b1f1
262247c
 
c45ac2f
 
 
262247c
c45ac2f
 
 
262247c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c45ac2f
262247c
 
 
c45ac2f
262247c
 
 
 
 
 
 
 
 
 
c45ac2f
 
 
262247c
c45ac2f
262247c
 
 
c45ac2f
262247c
 
 
c45ac2f
262247c
c45ac2f
262247c
 
 
 
 
 
c45ac2f
262247c
 
 
47ac53b
262247c
9e45e7f
0459a8f
 
 
 
 
262247c
 
 
 
 
 
9e45e7f
262247c
 
 
9e45e7f
a49c212
262247c
9e45e7f
 
 
 
262247c
 
 
 
 
 
 
1606bed
262247c
1606bed
262247c
 
 
 
 
 
 
 
48cbcfa
262247c
997c8f4
 
262247c
997c8f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262247c
997c8f4
 
 
262247c
 
 
 
 
 
997c8f4
 
262247c
0eea88d
48cbcfa
22a66e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
import os
from groq import Groq
import pandas as pd
from datasets import Dataset
from semantic_router.encoders import HuggingFaceEncoder

# Embedding model shared by query encoding (get_docs) and index sizing.
_ENCODER_MODEL_NAME = "dwzhu/e5-base-4k"
encoder = HuggingFaceEncoder(name=_ENCODER_MODEL_NAME)

# Embed one throwaway sentence so the vector width is read from the model
# rather than hard-coded; `dims` is used to size the Pinecone index.
_probe_batch = ["this is a test"]
embeds = encoder(_probe_batch)
dims = len(embeds[0])

############ VECTOR STORE (Pinecone) ############

import os
import getpass
from pinecone import Pinecone

# initialize connection to pinecone (get API key at app.pinecone.io)
api_key = os.getenv("PINECONE_API_KEY")

# configure client
pc = Pinecone(api_key=api_key)

from pinecone import ServerlessSpec

# Serverless placement; only used if the index has to be created below.
spec = ServerlessSpec(
    cloud="aws", region="us-east-1"
)

import time

index_name = "groq-llama-3-rag"
# Maximum seconds to wait for a newly created index to report ready before
# giving up, so app startup cannot hang forever on a stuck provisioning step.
INDEX_READY_TIMEOUT_S = 300

existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
]

# check if index already exists (it shouldn't if this is first time)
if index_name not in existing_indexes:
    # if it does not exist, create it sized to the encoder's output (dims)
    pc.create_index(
        index_name,
        dimension=dims,
        metric='cosine',
        spec=spec
    )
    # wait for index to be initialized, but only up to the timeout
    _ready_deadline = time.monotonic() + INDEX_READY_TIMEOUT_S
    while not pc.describe_index(index_name).status['ready']:
        if time.monotonic() >= _ready_deadline:
            raise TimeoutError(
                f"Pinecone index {index_name!r} was not ready after "
                f"{INDEX_READY_TIMEOUT_S} seconds"
            )
        time.sleep(1)

# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats (the result is discarded here; the call only round-trips
# to the index once at startup)
index.describe_index_stats()


def get_docs(query: str, top_k: int) -> list[str]:
    """Return text snippets of the ``top_k`` indexed chunks most similar to ``query``.

    The query is embedded with the module-level ``encoder`` and searched
    against the module-level Pinecone ``index``. Each match's
    ``metadata["content_snippet"]`` is returned, in match order. Matches
    whose metadata has no ``content_snippet`` are skipped.

    Args:
        query: Free-text user question.
        top_k: Maximum number of matches to request from the index.

    Returns:
        The snippet strings of the returned matches.
    """
    # encoder() embeds a batch of texts; take the single vector for our one
    # query, because Index.query expects one flat vector, not a batch.
    query_vector = encoder([query])[0]
    # search pinecone index
    res = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    # get doc text, tolerating matches indexed without a snippet
    docs = []
    for match in res["matches"]:
        snippet = (match["metadata"] or {}).get("content_snippet")
        if snippet is not None:
            docs.append(snippet)
    return docs

from groq import Groq

# Chat-completion client used by generate(); the API key is taken from the
# GROQ_API_KEY environment variable.
_groq_api_key = os.getenv("GROQ_API_KEY")
groq_client = Groq(api_key=_groq_api_key)

# Persona instructions sent as the first system message on every turn.
_NYC_PERSONA_PROMPT = (
    "You are a friendly and knowledgeable New Yorker who loves sharing recommendations about the city. "
    "You have lived in NYC for years and know both the famous tourist spots and hidden local gems. "
    "Your goal is to give recommendations tailored to what the user is asking for, whether they want iconic attractions "
    "or lesser-known spots loved by locals.\n\n"
    "Use the provided context to enhance your responses with real local insights, but only include details that are relevant "
    "to the user’s question. If the context provides useful recommendations that match what the user is asking for, use them. "
    "If the context is unrelated or does not fully answer the question, rely on your general NYC knowledge instead.\n\n"
    "Be specific when recommending places—mention neighborhoods, the atmosphere, and why someone might like a spot. "
    "Keep your tone warm, conversational, and engaging, like a close friend who genuinely enjoys sharing their city.\n\n"
)
# Separator placed between retrieved snippets inside the CONTEXT section.
_CONTEXT_SEPARATOR = "\n---\n"
# Number of snippets retrieved from the vector index for each query.
_CONTEXT_TOP_K = 5
# Groq model used for replies; overridable via the GROQ_MODEL env variable.
# NOTE(review): confirm the default model ID is still served by Groq.
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama3-70b-8192")


def _history_to_messages(history) -> list[dict]:
    """Convert Gradio chat history into chat-completion ``{"role", "content"}`` messages.

    Accepts both Gradio history shapes: ``[user, assistant]`` pairs and
    ``{"role": ..., "content": ...}`` dicts. Only ``role`` and ``content``
    are forwarded; ``None`` entries and non-text dict contents are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
            continue
        user_msg, bot_msg = turn
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


def generate(query: str, history):
    """Answer ``query`` as an NYC-local guide, grounded in retrieved context.

    Args:
        query: The user's newest message.
        history: Prior turns as passed by ``gr.ChatInterface`` (pairs or
            role/content dicts); empty on the first turn.

    Returns:
        The reply text of the Groq chat completion.
    """
    # Build the context with an explicit join: writing the literals adjacent to
    # `"...".join(...)` would concatenate them into the join separator.
    context = _CONTEXT_SEPARATOR.join(get_docs(query, top_k=_CONTEXT_TOP_K))
    prior_turns = _history_to_messages(history)

    if not prior_turns:
        # First turn: persona and retrieved context share one system message.
        messages = [
            {"role": "system", "content": _NYC_PERSONA_PROMPT + "CONTEXT:\n" + context},
        ]
    else:
        # Later turns: persona stays on top, each past turn is replayed once,
        # then context for the newest query is attached just before it.
        messages = [{"role": "system", "content": _NYC_PERSONA_PROMPT.rstrip()}]
        messages.extend(prior_turns)
        messages.append({
            "role": "system",
            "content": (
                "Here is additional context based on the newest query.\n\n"
                "CONTEXT:\n" + context
            ),
        })

    # Add query
    messages.append({"role": "user", "content": query})

    # generate response
    chat_response = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=messages
    )
    return chat_response.choices[0].message.content


# Custom CSS for iPhone-style chat; passed to gr.ChatInterface(css=...) when
# the demo is built. `.gradio-container` makes the page background transparent.
# NOTE(review): nothing in this file emits the .chat-message / .chat-bubble /
# .profile-pic classes; confirm Gradio's chatbot markup uses these names,
# otherwise those rules have no visible effect.
custom_css = """
.gradio-container {
    background: transparent !important;
}
.chat-message {
    display: flex;
    align-items: center;
    margin-bottom: 10px;
}
.chat-message.user {
    justify-content: flex-end;
}
.chat-message.assistant {
    justify-content: flex-start;
}
.chat-bubble {
    padding: 10px 15px;
    border-radius: 20px;
    max-width: 70%;
    font-size: 16px;
    display: inline-block;
}
.chat-bubble.user {
    background-color: #007aff;
    color: white;
    border-bottom-right-radius: 5px;
}
.chat-bubble.assistant {
    background-color: #f0f0f0;
    color: black;
    border-bottom-left-radius: 5px;
}
.profile-pic {
    width: 40px;
    height: 40px;
    border-radius: 50%;
    margin: 0 10px;
}
"""

# Gradio Interface: a chat UI that sends each user message through generate().
_nyc_textbox = gr.Textbox(placeholder="Ask me anything about NYC!")
_nyc_chatbot = gr.Chatbot(
    placeholder="<strong>NYC Buddy</strong><br>Looking for local tips, hidden gems, or iconic spots? Just ask!"
)
demo = gr.ChatInterface(
    generate,
    css=custom_css,
    textbox=_nyc_textbox,
    chatbot=_nyc_chatbot,
)

demo.launch()