Reality123b committed
Commit 130360f · verified · 1 Parent(s): 2008367

Update app.py

Files changed (1): app.py (+59, -58)
app.py CHANGED
@@ -3,10 +3,12 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from transformers import pipeline, TextStreamer
 import torch
-import queue
+import re
 import threading
+import queue
 import time
-import re
+import random
+import duckduckgo_search
 from duckduckgo_search import DDGS
 
 # ------------------------
@@ -15,100 +17,89 @@ from duckduckgo_search import DDGS
 MAIN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 QUERY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
 SUMMARY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
-DEVICE = "cpu" # set to 0 for GPU
+DEVICE = 0 if torch.cuda.is_available() else "cpu"
+
+DEEPSEEK_MAX_TOKENS = 64000
+SMOLLM_MAX_TOKENS = 4192
+
 KG_UPDATE_INTERVAL = 60 # seconds
-MAX_NEW_TOKENS = 64000
+knowledge_graph = {}
 
 # ------------------------
 # API + Models Init
 # ------------------------
 app = FastAPI()
 
+print("[Init] Loading models...")
 generator = pipeline("text-generation", model=MAIN_MODEL, device=DEVICE)
 query_generator = pipeline("text-generation", model=QUERY_MODEL, device=DEVICE)
 summarizer = pipeline("text-generation", model=SUMMARY_MODEL, device=DEVICE)
+print("[Init] Models loaded.")
 
-knowledge_graph = {}
-
-# ------------------------
-# Data Model
-# ------------------------
 class ModelInput(BaseModel):
     prompt: str
-    max_new_tokens: int = MAX_NEW_TOKENS
+    max_new_tokens: int = DEEPSEEK_MAX_TOKENS
 
 # ------------------------
 # KG Functions
 # ------------------------
-def clean_text(text):
-    return re.sub(r"\s+", " ", text).strip()
-
 def generate_dynamic_query():
-    """Generates a realistic short search query."""
     prompt = (
         "Generate a short, specific search query about technology, startups, AI, or science. "
         "Be creative, realistic, and output only the query with no extra words."
     )
     output = query_generator(
         prompt,
-        max_new_tokens=16,
+        max_new_tokens=SMOLLM_MAX_TOKENS,
        truncation=True,
        do_sample=True,
        temperature=1.0,
        top_p=0.9
     )[0]["generated_text"].strip()
-    # Take only first line and remove 'Generate'
     query = output.split("\n")[0]
     query = re.sub(r"^Generate.*?:", "", query).strip()
     return query
 
+def summarize_text(text):
+    summary_prompt = f"Summarize this in 3 concise sentences:\n\n{text}"
+    return summarizer(
+        summary_prompt,
+        max_new_tokens=SMOLLM_MAX_TOKENS,
+        truncation=True
+    )[0]["generated_text"].strip()
+
 def search_ddg(query):
     with DDGS() as ddgs:
         results = list(ddgs.text(query, max_results=5))
-    return " ".join([r.get("body", "") for r in results])
-
-def summarize_text(text):
-    summary_prompt = f"Summarize this in 3 concise sentences:\n\n{text}"
-    return summarizer(summary_prompt, max_new_tokens=100, truncation=True)[0]["generated_text"].strip()
+    combined = " ".join(r["body"] for r in results if "body" in r)
+    return combined.strip()
 
 def kg_updater():
     while True:
         try:
             query = generate_dynamic_query()
-            if not query or len(query) < 3:
-                time.sleep(KG_UPDATE_INTERVAL)
-                continue
-
             print(f"[KG Updater] Searching DDG for query: {query}")
-            raw_text = clean_text(search_ddg(query))
-            if len(raw_text) < 40:
+            raw_text = search_ddg(query)
+            if len(raw_text) < 50:
                 print("[KG Updater] Too little info found, retrying next cycle...")
-                time.sleep(KG_UPDATE_INTERVAL)
-                continue
-
-            summary = summarize_text(raw_text)
-            knowledge_graph[query] = {
-                "summary": summary,
-                "timestamp": time.time()
-            }
-            print(f"[KG Updater] Knowledge graph updated for query: {query}")
-
+            else:
+                summary = summarize_text(raw_text)
+                knowledge_graph[query] = summary
+                print(f"[KG Updater] Knowledge graph updated for query: {query}")
         except Exception as e:
-            print(f"[KG Updater] Error: {e}")
+            print(f"[KG Updater ERROR] {e}")
         time.sleep(KG_UPDATE_INTERVAL)
 
-# Start KG updater thread
 threading.Thread(target=kg_updater, daemon=True).start()
 
-# ------------------------
-# Prompt Injection
-# ------------------------
-def inject_relevant_kg(user_prompt):
-    # Simple keyword match for relevance
-    for query, data in knowledge_graph.items():
-        if any(word.lower() in user_prompt.lower() for word in query.split()):
-            return f"{user_prompt}\n\n[Relevant Info from Knowledge Graph]\n{data['summary']}\n"
-    return user_prompt
+def inject_relevant_kg(prompt):
+    relevant_info = ""
+    for k, v in knowledge_graph.items():
+        if any(word.lower() in prompt.lower() for word in k.split()):
+            relevant_info += f"\n[KG:{k}] {v}"
+    if relevant_info:
+        return f"{prompt}\n\nRelevant background info:\n{relevant_info}"
+    return prompt
 
 # ------------------------
 # Streaming Generation
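A note on the rewritten inject_relevant_kg: the match is a case-insensitive substring test on each word of a stored query key, so any shared word (even inside a longer word) triggers injection. A minimal, self-contained sketch of that behavior; the knowledge_graph entry below is invented purely for illustration:

knowledge_graph = {"AI startup funding": "Toy summary text for illustration."}

def inject_relevant_kg(prompt):
    relevant_info = ""
    for k, v in knowledge_graph.items():
        # Substring test per key word: "ai" also matches inside "maintain".
        if any(word.lower() in prompt.lower() for word in k.split()):
            relevant_info += f"\n[KG:{k}] {v}"
    if relevant_info:
        return f"{prompt}\n\nRelevant background info:\n{relevant_info}"
    return prompt

print(inject_relevant_kg("Any news on AI funding?"))
# Appends "[KG:AI startup funding] Toy summary text for illustration."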
@@ -122,18 +113,18 @@ async def generate_stream(input: ModelInput):
     tokenizer = generator.tokenizer
 
     def enqueue_token(token_ids):
-        if hasattr(token_ids, "tolist"): # tensor → list
+        if hasattr(token_ids, "tolist"):
             token_ids = token_ids.tolist()
         text = tokenizer.decode(token_ids, skip_special_tokens=True)
         q.put(text)
 
     streamer = TextStreamer(tokenizer, skip_prompt=True)
-    streamer.put = enqueue_token
+    streamer.put = enqueue_token # intercept tokens
 
     enriched_prompt = inject_relevant_kg(input.prompt)
     generator(
         enriched_prompt,
-        max_new_tokens=input.max_new_tokens,
+        max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
         do_sample=False,
         streamer=streamer
     )
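One caveat with the streamer.put override this hunk keeps: TextStreamer implements its skip_prompt handling inside put, so replacing put wholesale also enqueues the prompt tokens. transformers ships TextIteratorStreamer for exactly this queue-consumer pattern; below is a minimal sketch of that alternative, not what the commit does, and the worker-thread wrapper is an assumption since generation blocks the caller:

import threading
from transformers import TextIteratorStreamer

def stream_generation(generator, prompt, max_new_tokens):
    # TextIteratorStreamer owns an internal queue and honors skip_prompt,
    # so no monkey-patching of put() is needed.
    streamer = TextIteratorStreamer(generator.tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False, streamer=streamer)
    # Run generation in a worker thread; the streamer is drained by iterating it.
    threading.Thread(target=generator, args=(prompt,), kwargs=kwargs, daemon=True).start()
    for chunk in streamer:  # yields decoded text until generation finishes
        yield chunk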
@@ -146,29 +137,39 @@ async def generate_stream(input: ModelInput):
 
     async def event_generator():
         while True:
-            chunk = q.get()
-            if chunk is None:
+            token = q.get()
+            if token is None:
                 break
-            yield chunk
+            yield token
 
     return StreamingResponse(event_generator(), media_type="text/plain")
 
 # ------------------------
-# Endpoints
+# Non-stream endpoint
 # ------------------------
 @app.post("/generate")
 async def generate_text(input: ModelInput):
     try:
         enriched_prompt = inject_relevant_kg(input.prompt)
-        response = generator(enriched_prompt, max_new_tokens=input.max_new_tokens, do_sample=False)[0]["generated_text"]
-        return {"generated_text": response}
+        output = generator(
+            enriched_prompt,
+            max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
+            do_sample=False
+        )[0]["generated_text"]
+        return {"generated_text": output}
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+# ------------------------
+# KG endpoint
+# ------------------------
 @app.get("/knowledge")
 async def get_knowledge():
     return knowledge_graph
 
+# ------------------------
+# Root endpoint
+# ------------------------
 @app.get("/")
 async def root():
-    return {"message": "Welcome to the Cognitive Swarm Worker API with Streaming + KG!"}
+    return {"message": "Welcome to the Streaming Model API with KG Updater!"}
 