tonic committed on
Commit 6448a30
1 Parent(s): baedf33

Create app.py

Files changed (1)
  1. backend/app.py +183 -0
backend/app.py ADDED
@@ -0,0 +1,183 @@
+ import os
+ import weaviate
+ import cohere
+ import gradio as gr
+ from dotenv import load_dotenv
+ from openai import OpenAI as OpenAIClient
+ from langchain.embeddings import CohereEmbeddings
+ from langchain.memory import ConversationSummaryBufferMemory, ChatMessageHistory
+ from langchain.prompts.prompt import PromptTemplate
+ from langchain.document_loaders import UnstructuredFileLoader
+ from langchain.vectorstores import Weaviate
+ from langchain.llms import OpenAI
+ from langchain.chains.question_answering import load_qa_chain
+
+ # Load environment variables
+ load_dotenv()
+ openai_api_key = os.getenv('OPENAI')
+ cohere_api_key = os.getenv('COHERE')
+ weaviate_api_key = os.getenv('WEAVIATE')
+ weaviate_url = os.getenv('WEAVIATE_URL')
+
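+ # A .env file next to app.py is expected to define these keys; the variable
+ # names below match the os.getenv calls above, and the values shown are
+ # placeholders, not real credentials:
+ #
+ #   OPENAI=sk-...
+ #   COHERE=...
+ #   WEAVIATE=...
+ #   WEAVIATE_URL=https://your-cluster.weaviate.network
+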
+ # Prompt templates: one for the QA chain, one for the conversation summarizer
+ prompt_template = """
+ Answer the human using the retrieved context below.
+
+ {context}
+
+ {chat_history}
+ Human: {human_input}
+ Chatbot:
+ """
+
+ summary_prompt_template = """
+ Current summary:
+ {summary}
+
+ New lines of conversation:
+ {new_lines}
+
+ New summary:
+ """
+
+ # Initialize chat history
+ chat_history = ChatMessageHistory()
+
+ # Create prompt templates
+ summary_prompt = PromptTemplate(input_variables=["summary", "new_lines"], template=summary_prompt_template)
+ load_qa_chain_prompt = PromptTemplate(input_variables=["chat_history", "human_input", "context"], template=prompt_template)
+
+ # Shared LLM for the summary memory and the QA chain
+ llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
+
+ # Initialize memory; max_token_limit bounds the buffer before older turns
+ # are folded into the running summary
+ memory = ConversationSummaryBufferMemory(
+     llm=llm,
+     memory_key="chat_history",
+     input_key="human_input",
+     max_token_limit=5000,
+     prompt=summary_prompt,
+     chat_memory=chat_history
+ )
+
+ # Load QA chain with memory
+ qa_chain = load_qa_chain(llm=llm, chain_type="stuff", memory=memory, prompt=load_qa_chain_prompt)
+
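+ # Illustrative call shape for the chain above (the inputs are hypothetical):
+ # a "stuff" chain takes retrieved documents plus the prompt's human_input,
+ # while chat_history is injected from memory.
+ #
+ #   qa_chain({"input_documents": docs, "human_input": "Summarize the upload."})
+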
+ # Weaviate connection
+ auth_config = weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
+ client = weaviate.Client(url=weaviate_url, auth_client_secret=auth_config,
+                          additional_headers={"X-Cohere-Api-Key": cohere_api_key})
+
+ # Initialize vectorstore backed by the existing HereChat index
+ vectorstore = Weaviate(client, index_name="HereChat", text_key="text")
+ vectorstore._query_attrs = ["text", "title", "url", "views", "lang", "_additional {distance}"]
+ vectorstore.embedding = CohereEmbeddings(model="embed-multilingual-v2.0", cohere_api_key=cohere_api_key)
+
+ # Initialize Cohere client (embeddings + reranking)
+ co = cohere.Client(api_key=cohere_api_key)
+
+ # Separate OpenAI client for the final chat completion; `client` above is
+ # the Weaviate client and cannot be reused for OpenAI calls
+ openai_client = OpenAIClient(api_key=openai_api_key)
+
+ def embed_pdf(file, collection_name):
+     # Gradio's File component already saves the upload to a temp path,
+     # so the document can be loaded from file.name directly
+     file_path = file.name
+
+     # UnstructuredFileLoader detects the file type and parses accordingly
+     loader = UnstructuredFileLoader(file_path)
+     docs = loader.load()
+
+     # Generate embeddings and store documents in Weaviate
+     embeddings = CohereEmbeddings(model="embed-multilingual-v2.0", cohere_api_key=cohere_api_key)
+     for doc in docs:
+         # embed_documents returns one vector per input text
+         vector = embeddings.embed_documents([doc.page_content])[0]
+         # Weaviate takes the vector separately from the object's properties
+         client.data_object.create(
+             data_object={"text": doc.page_content},
+             class_name=collection_name,
+             vector=vector
+         )
+
+     return {"message": f"Documents embedded in Weaviate collection '{collection_name}'"}
+
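+ # Note: the target class must already exist in Weaviate before objects are
+ # created. A minimal schema sketch (class and property names are assumptions
+ # matching the code above):
+ #
+ #   client.schema.create_class({
+ #       "class": "HereChat",
+ #       "vectorizer": "none",
+ #       "properties": [{"name": "text", "dataType": ["text"]}],
+ #   })
+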
+ def update_chat_history(user_message, ai_message):
+     chat_history.add_user_message(user_message)
+     chat_history.add_ai_message(ai_message)
+     # prune() folds older turns into the running summary once the buffer
+     # exceeds max_token_limit, and is a no-op otherwise
+     memory.prune()
+
+ def retrieve_info(query):
+     update_chat_history(query, "")
+
+     # Pull candidate documents from the vectorstore retriever; reranking
+     # needs the raw documents, which RetrievalQA's answer string would not
+     # provide
+     retriever = vectorstore.as_retriever(search_kwargs={"k": 25})
+     initial_docs = retriever.get_relevant_documents(query)
+     top_docs = [doc.page_content for doc in initial_docs]
+
+     # Rerank the top results with Cohere and keep the best three
+     reranked_results = co.rerank(query=query, documents=top_docs, top_n=3, model='rerank-english-v2.0')
+
+     # Append the reranked documents and their relevance scores to the user prompt
+     user_prompt = f"User: {query}\n"
+     for idx, r in enumerate(reranked_results):
+         user_prompt += f"Document {idx + 1}: {r.document['text']}\nRelevance Score: {r.relevance_score:.2f}\n\n"
+
+     # Final API call to OpenAI; this uses openai_client, not the Weaviate
+     # `client`
+     final_response = openai_client.chat.completions.create(
+         model="gpt-4-1106-preview",
+         messages=[
+             {
+                 "role": "system",
+                 "content": "You are a redditor. Assess, rephrase, and explain the following. Provide long answers. Use the same words and language you receive."
+             },
+             {
+                 "role": "user",
+                 "content": user_prompt
+             }
+         ],
+         temperature=1.63,
+         max_tokens=2240,
+         top_p=1,
+         frequency_penalty=1.73,
+         presence_penalty=1.76
+     )
+
+     # Chat completions expose the text under message.content
+     return final_response.choices[0].message.content
+
+ def combined_interface(query, file, collection_name):
+     if query:
+         return retrieve_info(query)
+     elif file is not None and collection_name:
+         # embed_pdf returns a dict; surface just the message for the text output
+         return embed_pdf(file, collection_name)["message"]
+     else:
+         return "Please enter a query or upload a PDF file."
+
+ iface = gr.Interface(
+     fn=combined_interface,
+     inputs=[
+         gr.Textbox(label="Query"),
+         gr.File(label="PDF File"),
+         gr.Textbox(label="Collection Name")
+     ],
+     outputs="text"
+ )
+
+ iface.launch()
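+
+ # launch() serves the UI on http://127.0.0.1:7860 by default; pass
+ # server_name="0.0.0.0" and/or share=True to expose it beyond localhost.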