makkzone committed
Commit 62cc824 · verified · 1 Parent(s): 60645c3

Create app.py

Files changed (1)
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
import base64
import operator
from typing import TypedDict, Annotated, List

import requests
import streamlit as st
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import Ollama
from langchain.chains import RetrievalQA
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import StateGraph, END

# Fetch repository metadata and README text from the GitHub API
def fetch_github_data(repo_url):
    # Extract "owner/repo" from the URL, tolerating a trailing slash
    parts = repo_url.rstrip('/').split('/')
    owner, repo = parts[-2], parts[-1]

    headers = {'Accept': 'application/vnd.github.v3+json'}
    base_url = 'https://api.github.com'

    content = ""
    repo_response = requests.get(f"{base_url}/repos/{owner}/{repo}", headers=headers)
    if repo_response.status_code == 200:
        repo_data = repo_response.json()
        # The description field can be null, so fall back to an empty string
        content += f"Description: {repo_data.get('description') or ''}\n"

    readme_response = requests.get(f"{base_url}/repos/{owner}/{repo}/readme", headers=headers)
    if readme_response.status_code == 200:
        readme_data = readme_response.json()
        # The API returns the README body base64-encoded
        content += base64.b64decode(readme_data['content']).decode('utf-8') + "\n"

    return content

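# Note: unauthenticated GitHub API requests are rate-limited (60 per hour per IP).
# A minimal sketch of authenticated requests, assuming a GITHUB_TOKEN environment
# variable (hypothetical here; this app ships without auth):
#
#   import os
#   token = os.environ.get("GITHUB_TOKEN")
#   if token:
#       headers["Authorization"] = f"Bearer {token}"
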
# Build a FAISS vector store over the fetched repo text
def create_vector_store(text_data):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    chunks = text_splitter.split_text(text_data)

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_texts(chunks, embeddings)
    return vector_store

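# The 200-character overlap keeps sentences that straddle a chunk boundary
# retrievable from both neighboring chunks; all-MiniLM-L6-v2 is a small
# sentence-transformer model producing 384-dimensional embeddings.
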
# Define the state carried between LangGraph nodes
class GraphState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    generation: str
    search_count: int

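# Annotating "messages" with operator.add tells LangGraph to append each node's
# returned messages to the running list instead of overwriting it, so the full
# message history survives across the search/evaluate loop.
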
# Node: perform an initial or follow-up vector store search
def search_vector_store(state: GraphState):
    llm = st.session_state.llm
    vector_store = st.session_state.vector_store
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]
    search_count = state["search_count"] + 1

    # Modify the query slightly for follow-up searches
    if search_count > 1:
        query = f"{question} (additional context for: {current_generation})"
    else:
        query = question

    retriever = vector_store.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever
    )
    response = qa_chain.run(query)

    # Append the new information to the existing generation
    new_generation = f"{current_generation}\nAdditional Info: {response}" if current_generation else response

    return {
        "messages": [AIMessage(content=new_generation)],
        "generation": new_generation,
        "search_count": search_count
    }

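# chain_type="stuff" concatenates every retrieved chunk into a single prompt,
# which is fine for a handful of 1000-character chunks but can overflow the
# model's context window on larger corpora.
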
# Node: ask the LLM whether the accumulated answer is sufficient
def evaluate_sufficiency(state: GraphState):
    llm = st.session_state.llm
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]

    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_generation}'\n"
        f"Is this sufficient to fully answer the question? Respond with 'Yes' or 'No'."
    )
    decision = llm.invoke(prompt).strip()

    return {
        "messages": [AIMessage(content=f"Sufficiency check: {decision}")],
        "generation": current_generation,
        "search_count": state["search_count"]
    }

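# The router below looks for a literal "Yes" in the sufficiency message; if the
# model replies with extra prose (e.g. "yes, because..."), a case-insensitive
# check such as decision.lower().startswith("yes") would be more forgiving.
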
# Node: produce the final answer from the accumulated information
def finalize_answer(state: GraphState):
    llm = st.session_state.llm
    current_info = state["generation"]
    question = state["messages"][0].content  # Original question
    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_info}'\n"
        f"Answer the question as if you are answering it for the first time."
    )
    final_answer = llm.invoke(prompt).strip()
    return {
        "messages": [AIMessage(content=f"Final Answer: {final_answer}")],
        "generation": final_answer,
        "search_count": state["search_count"]
    }

# Conditional edge: decide whether to search again or finalize
def route_next_step(state: GraphState):
    last_message = state["messages"][-1].content
    search_count = state["search_count"]

    if "Sufficiency check: Yes" in last_message:
        return "finalize_answer"
    elif search_count >= 5:
        return "finalize_answer"  # Cap at 5 search iterations
    else:
        return "search_vector_store"

# Build the LangGraph workflow
def build_graph():
    workflow = StateGraph(GraphState)

    workflow.add_node("search_vector_store", search_vector_store)
    workflow.add_node("evaluate_sufficiency", evaluate_sufficiency)
    workflow.add_node("finalize_answer", finalize_answer)

    workflow.set_entry_point("search_vector_store")
    workflow.add_edge("search_vector_store", "evaluate_sufficiency")
    workflow.add_conditional_edges(
        "evaluate_sufficiency",
        route_next_step,
        {
            "search_vector_store": "search_vector_store",
            "finalize_answer": "finalize_answer"
        }
    )
    workflow.add_edge("finalize_answer", END)

    return workflow.compile()

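# Resulting control flow:
#   search_vector_store -> evaluate_sufficiency
#     - sufficiency "Yes" or 5 searches reached -> finalize_answer -> END
#     - otherwise -> back to search_vector_store
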
# Streamlit app
def main():
    st.title("Project Resilience Q&A Assistant")
    st.write("Ask anything about Project Resilience - answers always come from repo data!")

    # Hardcoded GitHub URL
    github_url = 'https://github.com/Project-Resilience/platform'

    # Initialize session state once; fetch repo data only on the first run
    if 'vector_store' not in st.session_state:
        repo_data = fetch_github_data(github_url)
        st.session_state.vector_store = create_vector_store(repo_data)
        st.session_state.llm = Ollama(model="llama3.2", temperature=0.7)
        st.session_state.graph = build_graph()

    # Question input
    question = st.text_input("Ask a question about Project Resilience")

    # Get and display the answer
    if question and st.session_state.graph:
        with st.spinner("Generating answer..."):
            initial_state = {
                "messages": [HumanMessage(content=question)],
                "generation": "",
                "search_count": 0
            }
            result = st.session_state.graph.invoke(initial_state)
            final_answer = result["generation"]
            st.write("**Answer:**")
            st.write(final_answer)

    # Sidebar with additional info
    st.sidebar.header("Project Resilience Assistant")
    st.sidebar.write("""
    Project Resilience's platform for decision makers, data scientists, and the public.

    Project Resilience, initiated under the Global Initiative on AI and Data Commons, is a collaborative effort to build a public AI utility that could inform and help address global decision-augmentation challenges.

    The project empowers a global community of innovators, thought leaders, and the public to enhance and use a shared collection of data and AI tools, improving preparedness, intervention, and response to environmental, health, information, or economic threats in our communities. It also supports broader efforts toward achieving the Sustainable Development Goals (SDGs).
    """)

if __name__ == "__main__":
    main()
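
# To run locally (a sketch, assuming Ollama is installed and serving the
# llama3.2 model, pulled beforehand with `ollama pull llama3.2`):
#
#   pip install streamlit langchain langchain-community langgraph faiss-cpu sentence-transformers
#   streamlit run app.py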