Reality123b committed on
Commit 2008367 · verified · 1 Parent(s): e41b8bc

Update app.py

Files changed (1):
  1. app.py +112 -190
app.py CHANGED
@@ -1,172 +1,133 @@
-from fastapi import FastAPI
-from fastapi.responses import StreamingResponse, HTMLResponse
 from pydantic import BaseModel
 from transformers import pipeline, TextStreamer
-import asyncio
-import httpx
-import time
 import queue
 import threading
-import random
 import re

-# =========================
-# CONFIG
-# =========================
-UPDATE_INTERVAL = 60  # seconds between KG updates
-MAX_KG_SIZE = 50  # limit stored KG nodes to avoid memory bloat
-
-# =========================
-# MODELS
-# =========================
-# Main generator
-generator = pipeline(
-    "text-generation",
-    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
-    device="cpu"
-)
-
-# Query + summarization model (SmolLM2 instruct)
-query_generator = pipeline(
-    "text-generation",
-    model="HuggingFaceTB/SmolLM2-360M-Instruct",
-    device="cpu"
-)
-
-summarizer = query_generator  # same model for now
-
-# =========================
-# KNOWLEDGE GRAPH
-# =========================
-knowledge_graph = {}

-# =========================
-# FASTAPI
-# =========================
-app = FastAPI()

 class ModelInput(BaseModel):
     prompt: str
-    max_new_tokens: int = 64000
-
-# =========================
-# UTILS
-# =========================
-async def fetch_ddg_search(query: str):
-    url = "https://api.duckduckgo.com/"
-    params = {
-        "q": query,
-        "format": "json",
-        "no_redirect": "1",
-        "no_html": "1",
-        "skip_disambig": "1"
-    }
-    async with httpx.AsyncClient() as client:
-        resp = await client.get(url, params=params, timeout=15)
-        return resp.json()
-
-def clean_ddg_text(ddg_json):
-    abstract = ddg_json.get("AbstractText", "")
-    related = ddg_json.get("RelatedTopics", [])
-    related_texts = []
-    for item in related:
-        if isinstance(item, dict) and "Text" in item:
-            related_texts.append(item["Text"])
-        elif isinstance(item, dict) and "Topics" in item:
-            for sub in item["Topics"]:
-                if "Text" in sub:
-                    related_texts.append(sub["Text"])
-    combined = (abstract + " " + " ".join(related_texts)).strip()
-    combined = re.sub(r"\s+", " ", combined)
-    if len(combined) > 1000:
-        combined = combined[:1000] + "..."
-    return combined

 def generate_dynamic_query():
     prompt = (
         "Generate a short, specific search query about technology, startups, AI, or science. "
         "Be creative, realistic, and output only the query with no extra words."
     )
     output = query_generator(
         prompt,
-        max_new_tokens=32,
         truncation=True,
         do_sample=True,
-        temperature=0.9
-    )
-    query = output[0]["generated_text"].strip().split("\n")[0]
     return query

-def summarize_text(text: str):
-    prompt = f"Summarize this concisely:\n{text}\nSummary:"
-    output = summarizer(
-        prompt,
-        max_new_tokens=256,
-        truncation=True,
-        do_sample=False
-    )
-    return output[0]["generated_text"].strip()
-
-def inject_relevant_kg(prompt: str):
-    """Find relevant KG entries and inject into prompt."""
-    if not knowledge_graph:
-        return prompt
-    best_match = None
-    for key, node in knowledge_graph.items():
-        if any(word.lower() in prompt.lower() for word in key.split()):
-            best_match = node
-            break
-    if best_match:
-        return f"{prompt}\n\nRelevant knowledge from memory:\n{best_match['summary']}"
-    return prompt
-
-# =========================
-# BACKGROUND TASK
-# =========================
-async def update_knowledge_graph_periodically():
     while True:
         try:
             query = generate_dynamic_query()
-            print(f"[KG Updater] Searching DDG for query: {query}")
-            ddg_data = await fetch_ddg_search(query)
-            cleaned = clean_ddg_text(ddg_data)

-            if not cleaned or len(cleaned) < 50:
                 print("[KG Updater] Too little info found, retrying next cycle...")
-            else:
-                summary = summarize_text(cleaned)
-                knowledge_graph[query] = {
-                    "raw_text": cleaned,
-                    "summary": summary,
-                    "timestamp": time.time()
-                }
-                if len(knowledge_graph) > MAX_KG_SIZE:
-                    # remove oldest
-                    oldest_key = min(knowledge_graph, key=lambda k: knowledge_graph[k]['timestamp'])
-                    del knowledge_graph[oldest_key]
-                print(f"[KG Updater] Knowledge graph updated for query: {query}")

         except Exception as e:
             print(f"[KG Updater] Error: {e}")
-
-        await asyncio.sleep(UPDATE_INTERVAL)
-
-@app.on_event("startup")
-async def startup_event():
-    asyncio.create_task(update_knowledge_graph_periodically())
-
-# =========================
-# STREAMING ENDPOINT
-# =========================
 @app.post("/generate/stream")
 async def generate_stream(input: ModelInput):
     q = queue.Queue()

     def run_generation():
         try:
-            streamer = TextStreamer(generator.tokenizer, skip_prompt=True)
-            def enqueue_token(token):
-                q.put(token)
             streamer.put = enqueue_token

             enriched_prompt = inject_relevant_kg(input.prompt)
@@ -181,72 +142,33 @@ async def generate_stream(input: ModelInput):
         finally:
             q.put(None)

-    thread = threading.Thread(target=run_generation)
-    thread.start()

     async def event_generator():
-        loop = asyncio.get_event_loop()
         while True:
-            token = await loop.run_in_executor(None, q.get)
-            if token is None:
                 break
-            yield token

     return StreamingResponse(event_generator(), media_type="text/plain")

-# =========================
-# VIEW KG
-# =========================
 @app.get("/knowledge")
 async def get_knowledge():
     return knowledge_graph

-# =========================
-# TEST CLIENT PAGE
-# =========================
-@app.get("/", response_class=HTMLResponse)
 async def root():
-    return """
-    <!DOCTYPE html>
-    <html>
-    <head><title>Xylaria Cognitive Worker</title></head>
-    <body>
-    <h2>Xylaria Cognitive Worker</h2>
-    <textarea id="prompt" rows="4" cols="60">Explain how AI startups secure funding</textarea><br/>
-    <button onclick="startStreaming()">Generate</button>
-    <pre id="output" style="white-space: pre-wrap; background:#eee; padding:10px; border-radius:5px; max-height:400px; overflow:auto;"></pre>
-    <h3>Knowledge Graph</h3>
-    <pre id="kg" style="background:#ddd; padding:10px; max-height:300px; overflow:auto;"></pre>
-
-    <script>
-    async function startStreaming() {
-        const prompt = document.getElementById("prompt").value;
-        const output = document.getElementById("output");
-        output.textContent = "";
-        const response = await fetch("/generate/stream", {
-            method: "POST",
-            headers: { "Content-Type": "application/json" },
-            body: JSON.stringify({ prompt: prompt, max_new_tokens: 64000 })
-        });
-        const reader = response.body.getReader();
-        const decoder = new TextDecoder();
-        while(true) {
-            const {done, value} = await reader.read();
-            if(done) break;
-            const chunk = decoder.decode(value, {stream: true});
-            output.textContent += chunk;
-            output.scrollTop = output.scrollHeight;
-        }
-    }
-    async function fetchKG() {
-        const kgPre = document.getElementById("kg");
-        const res = await fetch("/knowledge");
-        const data = await res.json();
-        kgPre.textContent = JSON.stringify(data, null, 2);
-    }
-    setInterval(fetchKG, 10000);
-    window.onload = fetchKG;
-    </script>
-    </body>
-    </html>
-    """
 
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from transformers import pipeline, TextStreamer
+import torch
 import queue
 import threading
+import time
 import re
+from duckduckgo_search import DDGS
+
+# ------------------------
+# Config
+# ------------------------
+MAIN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+QUERY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
+SUMMARY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
+DEVICE = "cpu"  # set to 0 for GPU
+KG_UPDATE_INTERVAL = 60  # seconds
+MAX_NEW_TOKENS = 64000
+
+# ------------------------
+# API + Models Init
+# ------------------------
+app = FastAPI()

+generator = pipeline("text-generation", model=MAIN_MODEL, device=DEVICE)
+query_generator = pipeline("text-generation", model=QUERY_MODEL, device=DEVICE)
+summarizer = pipeline("text-generation", model=SUMMARY_MODEL, device=DEVICE)

+knowledge_graph = {}

+# ------------------------
+# Data Model
+# ------------------------
 class ModelInput(BaseModel):
     prompt: str
+    max_new_tokens: int = MAX_NEW_TOKENS
+
+# ------------------------
+# KG Functions
+# ------------------------
+def clean_text(text):
+    return re.sub(r"\s+", " ", text).strip()

 def generate_dynamic_query():
+    """Generates a realistic short search query."""
     prompt = (
         "Generate a short, specific search query about technology, startups, AI, or science. "
         "Be creative, realistic, and output only the query with no extra words."
     )
     output = query_generator(
         prompt,
+        max_new_tokens=16,
         truncation=True,
         do_sample=True,
+        temperature=1.0,
+        top_p=0.9
+    )[0]["generated_text"].strip()
+    # Take only first line and remove 'Generate'
+    query = output.split("\n")[0]
+    query = re.sub(r"^Generate.*?:", "", query).strip()
     return query

+def search_ddg(query):
+    with DDGS() as ddgs:
+        results = list(ddgs.text(query, max_results=5))
+    return " ".join([r.get("body", "") for r in results])
+
+def summarize_text(text):
+    summary_prompt = f"Summarize this in 3 concise sentences:\n\n{text}"
+    return summarizer(summary_prompt, max_new_tokens=100, truncation=True)[0]["generated_text"].strip()
+
+def kg_updater():
     while True:
         try:
             query = generate_dynamic_query()
+            if not query or len(query) < 3:
+                time.sleep(KG_UPDATE_INTERVAL)
+                continue

+            print(f"[KG Updater] Searching DDG for query: {query}")
+            raw_text = clean_text(search_ddg(query))
+            if len(raw_text) < 40:
                 print("[KG Updater] Too little info found, retrying next cycle...")
+                time.sleep(KG_UPDATE_INTERVAL)
+                continue
+
+            summary = summarize_text(raw_text)
+            knowledge_graph[query] = {
+                "summary": summary,
+                "timestamp": time.time()
+            }
+            print(f"[KG Updater] Knowledge graph updated for query: {query}")

         except Exception as e:
             print(f"[KG Updater] Error: {e}")
+        time.sleep(KG_UPDATE_INTERVAL)
+
+# Start KG updater thread
+threading.Thread(target=kg_updater, daemon=True).start()
+
+# ------------------------
+# Prompt Injection
+# ------------------------
+def inject_relevant_kg(user_prompt):
+    # Simple keyword match for relevance
+    for query, data in knowledge_graph.items():
+        if any(word.lower() in user_prompt.lower() for word in query.split()):
+            return f"{user_prompt}\n\n[Relevant Info from Knowledge Graph]\n{data['summary']}\n"
+    return user_prompt
+
+# ------------------------
+# Streaming Generation
+# ------------------------
 @app.post("/generate/stream")
 async def generate_stream(input: ModelInput):
     q = queue.Queue()

     def run_generation():
         try:
+            tokenizer = generator.tokenizer
+
+            def enqueue_token(token_ids):
+                if hasattr(token_ids, "tolist"):  # tensor → list
+                    token_ids = token_ids.tolist()
+                text = tokenizer.decode(token_ids, skip_special_tokens=True)
+                q.put(text)
+
+            streamer = TextStreamer(tokenizer, skip_prompt=True)
             streamer.put = enqueue_token

             enriched_prompt = inject_relevant_kg(input.prompt)
@@ -181,72 +142,33 @@ async def generate_stream(input: ModelInput):
         finally:
             q.put(None)

+    threading.Thread(target=run_generation, daemon=True).start()

     async def event_generator():
         while True:
+            chunk = q.get()
+            if chunk is None:
                 break
+            yield chunk

     return StreamingResponse(event_generator(), media_type="text/plain")

+# ------------------------
+# Endpoints
+# ------------------------
+@app.post("/generate")
+async def generate_text(input: ModelInput):
+    try:
+        enriched_prompt = inject_relevant_kg(input.prompt)
+        response = generator(enriched_prompt, max_new_tokens=input.max_new_tokens, do_sample=False)[0]["generated_text"]
+        return {"generated_text": response}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
 @app.get("/knowledge")
 async def get_knowledge():
     return knowledge_graph

+@app.get("/")
 async def root():
+    return {"message": "Welcome to the Cognitive Swarm Worker API with Streaming + KG!"}