from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import pipeline, TextIteratorStreamer
import torch
import asyncio
import threading
import time
from duckduckgo_search import DDGS
# ------------------------
# Config
# ------------------------
MAIN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
QUERY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
SUMMARY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
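# The same small instruct model is reused for query generation and
# summarization; two separate pipeline handles are created below.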
DEVICE = 0 if torch.cuda.is_available() else "cpu"
DEEPSEEK_MAX_TOKENS = 64000
SMOLLM_MAX_TOKENS = 4096
KG_UPDATE_INTERVAL = 60 # seconds
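# In-memory knowledge graph: maps each generated search query to a summary
# of its DuckDuckGo results. It lives only as long as the process.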
knowledge_graph = {}
# ------------------------
# API + Models Init
# ------------------------
app = FastAPI()
print("[Init] Loading models...")
generator = pipeline("text-generation", model=MAIN_MODEL, device=DEVICE)
query_generator = pipeline("text-generation", model=QUERY_MODEL, device=DEVICE)
summarizer = pipeline("text-generation", model=SUMMARY_MODEL, device=DEVICE)
print("[Init] Models loaded.")
class ModelInput(BaseModel):
    prompt: str
    max_new_tokens: int = DEEPSEEK_MAX_TOKENS
# ------------------------
# KG Functions
# ------------------------
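# Background pipeline: generate a search query -> search DuckDuckGo ->
# summarize the results -> store the summary in the knowledge graph.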
def generate_dynamic_query():
    prompt = (
        "Generate a short, specific search query about technology, startups, AI, or science. "
        "Be creative, realistic, and output only the query with no extra words."
    )
    # return_full_text=False strips the prompt from the pipeline output, so the
    # old regex cleanup of the instruction text is unnecessary. A one-line query
    # needs only a few tokens, so the budget is capped well below SMOLLM_MAX_TOKENS.
    output = query_generator(
        prompt,
        max_new_tokens=64,
        truncation=True,
        do_sample=True,
        temperature=1.0,
        top_p=0.9,
        return_full_text=False,
    )[0]["generated_text"].strip()
    # Keep only the first line in case the model rambles past the query.
    return output.split("\n")[0].strip().strip('"')
def summarize_text(text):
    summary_prompt = f"Summarize this in 3 concise sentences:\n\n{text}"
    # return_full_text=False keeps the prompt (and the full source text)
    # out of the returned summary.
    return summarizer(
        summary_prompt,
        max_new_tokens=SMOLLM_MAX_TOKENS,
        truncation=True,
        return_full_text=False,
    )[0]["generated_text"].strip()
def search_ddg(query):
    with DDGS() as ddgs:
        results = list(ddgs.text(query, max_results=5))
    combined = " ".join(r["body"] for r in results if "body" in r)
    return combined.strip()
def kg_updater():
    while True:
        try:
            query = generate_dynamic_query()
            print(f"[KG Updater] Searching DDG for query: {query}")
            raw_text = search_ddg(query)
            if len(raw_text) < 50:
                print("[KG Updater] Too little info found, retrying next cycle...")
            else:
                summary = summarize_text(raw_text)
                knowledge_graph[query] = summary
                print(f"[KG Updater] Knowledge graph updated for query: {query}")
        except Exception as e:
            print(f"[KG Updater ERROR] {e}")
        time.sleep(KG_UPDATE_INTERVAL)
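# Run the updater as a daemon thread so it starts with the app and never
# blocks interpreter shutdown.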
threading.Thread(target=kg_updater, daemon=True).start()
def inject_relevant_kg(prompt):
    # Naive relevance check: any word from a stored query that appears
    # anywhere in the prompt counts as a match.
    relevant_info = ""
    for k, v in knowledge_graph.items():
        if any(word.lower() in prompt.lower() for word in k.split()):
            relevant_info += f"\n[KG:{k}] {v}"
    if relevant_info:
        return f"{prompt}\n\nRelevant background info:\n{relevant_info}"
    return prompt
# ------------------------
# Streaming Generation
# ------------------------
@app.post("/generate/stream")
async def generate_stream(input: ModelInput):
q = queue.Queue()
def run_generation():
try:
tokenizer = generator.tokenizer
def enqueue_token(token_ids):
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
text = tokenizer.decode(token_ids, skip_special_tokens=True)
q.put(text)
streamer = TextStreamer(tokenizer, skip_prompt=True)
streamer.put = enqueue_token # intercept tokens
enriched_prompt = inject_relevant_kg(input.prompt)
generator(
enriched_prompt,
max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
do_sample=False,
streamer=streamer
)
except Exception as e:
q.put(f"[ERROR] {e}")
finally:
q.put(None)
threading.Thread(target=run_generation, daemon=True).start()
async def event_generator():
while True:
token = q.get()
if token is None:
break
yield token
return StreamingResponse(event_generator(), media_type="text/plain")
# ------------------------
# Non-stream endpoint
# ------------------------
@app.post("/generate")
async def generate_text(input: ModelInput):
try:
enriched_prompt = inject_relevant_kg(input.prompt)
output = generator(
enriched_prompt,
max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
do_sample=False
)[0]["generated_text"]
return {"generated_text": output}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ------------------------
# KG endpoint
# ------------------------
@app.get("/knowledge")
async def get_knowledge():
return knowledge_graph
# ------------------------
# Root endpoint
# ------------------------
@app.get("/")
async def root():
return {"message": "Welcome to the Streaming Model API with KG Updater!"}