"""FastAPI service that serves text generation from a DeepSeek model and keeps a
small knowledge store refreshed in the background via DuckDuckGo searches."""

import asyncio
import queue
import re
import threading
import time

import torch
from duckduckgo_search import DDGS
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import TextStreamer, pipeline
MAIN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
QUERY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
SUMMARY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
DEVICE = 0 if torch.cuda.is_available() else "cpu"

# Upper bounds on generated tokens for each model.
DEEPSEEK_MAX_TOKENS = 64000
SMOLLM_MAX_TOKENS = 4096

# Seconds between background refreshes of the knowledge graph.
KG_UPDATE_INTERVAL = 60
knowledge_graph = {}
app = FastAPI()

print("[Init] Loading models...")
generator = pipeline("text-generation", model=MAIN_MODEL, device=DEVICE)
query_generator = pipeline("text-generation", model=QUERY_MODEL, device=DEVICE)
summarizer = pipeline("text-generation", model=SUMMARY_MODEL, device=DEVICE)
print("[Init] Models loaded.")


class ModelInput(BaseModel):
    prompt: str
    max_new_tokens: int = DEEPSEEK_MAX_TOKENS
def generate_dynamic_query():
    prompt = (
        "Generate a short, specific search query about technology, startups, AI, or science. "
        "Be creative, realistic, and output only the query with no extra words."
    )
    output = query_generator(
        prompt,
        max_new_tokens=SMOLLM_MAX_TOKENS,
        truncation=True,
        do_sample=True,
        temperature=1.0,
        top_p=0.9,
        return_full_text=False,  # return only the completion, not the echoed prompt
    )[0]["generated_text"].strip()
    # Keep the first line and strip any echoed instruction prefix.
    query = output.split("\n")[0]
    query = re.sub(r"^Generate.*?:", "", query).strip()
    return query
def summarize_text(text):
    summary_prompt = f"Summarize this in 3 concise sentences:\n\n{text}"
    return summarizer(
        summary_prompt,
        max_new_tokens=SMOLLM_MAX_TOKENS,
        truncation=True,
        return_full_text=False,  # return only the summary, not the echoed prompt
    )[0]["generated_text"].strip()
def search_ddg(query):
    with DDGS() as ddgs:
        results = list(ddgs.text(query, max_results=5))
    combined = " ".join(r["body"] for r in results if "body" in r)
    return combined.strip()
def kg_updater():
    while True:
        try:
            query = generate_dynamic_query()
            print(f"[KG Updater] Searching DDG for query: {query}")
            raw_text = search_ddg(query)
            if len(raw_text) < 50:
                print("[KG Updater] Too little info found, retrying next cycle...")
            else:
                summary = summarize_text(raw_text)
                knowledge_graph[query] = summary
                print(f"[KG Updater] Knowledge graph updated for query: {query}")
        except Exception as e:
            print(f"[KG Updater ERROR] {e}")
        time.sleep(KG_UPDATE_INTERVAL)


# Refresh the knowledge graph periodically in a background daemon thread.
threading.Thread(target=kg_updater, daemon=True).start()
def inject_relevant_kg(prompt):
    relevant_info = ""
    for k, v in knowledge_graph.items():
        if any(word.lower() in prompt.lower() for word in k.split()):
            relevant_info += f"\n[KG:{k}] {v}"
    if relevant_info:
        return f"{prompt}\n\nRelevant background info:\n{relevant_info}"
    return prompt
@app.post("/generate/stream")
async def generate_stream(input: ModelInput):
    q = queue.Queue()

    def run_generation():
        try:
            tokenizer = generator.tokenizer
            prompt_skipped = False

            def enqueue_token(token_ids):
                nonlocal prompt_skipped
                if hasattr(token_ids, "tolist"):
                    token_ids = token_ids.tolist()
                if not prompt_skipped:
                    # Overriding put() bypasses TextStreamer's own skip_prompt logic;
                    # the first call carries the prompt ids, so drop it here.
                    prompt_skipped = True
                    return
                # Decoding token-by-token can split multi-byte characters,
                # but is adequate for plain-text streaming.
                text = tokenizer.decode(token_ids, skip_special_tokens=True)
                q.put(text)

            streamer = TextStreamer(tokenizer, skip_prompt=True)
            streamer.put = enqueue_token

            enriched_prompt = inject_relevant_kg(input.prompt)
            generator(
                enriched_prompt,
                max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
                do_sample=False,
                streamer=streamer
            )
        except Exception as e:
            q.put(f"[ERROR] {e}")
        finally:
            q.put(None)  # sentinel: generation finished

    threading.Thread(target=run_generation, daemon=True).start()

    async def event_generator():
        while True:
            # Fetch from the queue in a worker thread so the event loop is not blocked.
            token = await asyncio.to_thread(q.get)
            if token is None:
                break
            yield token

    return StreamingResponse(event_generator(), media_type="text/plain")
@app.post("/generate")
async def generate_text(input: ModelInput):
    try:
        enriched_prompt = inject_relevant_kg(input.prompt)
        output = generator(
            enriched_prompt,
            max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
            do_sample=False
        )[0]["generated_text"]
        return {"generated_text": output}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/knowledge")
async def get_knowledge():
    return knowledge_graph


@app.get("/")
async def root():
    return {"message": "Welcome to the Streaming Model API with KG Updater!"}
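

# Example usage (a sketch, assuming this file is saved as app.py and served on the
# default port 8000; adjust the module name and URL to your setup):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   # Non-streaming generation
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain what a knowledge graph is.", "max_new_tokens": 256}'
#
#   # Streaming generation (tokens arrive as plain-text chunks)
#   curl -N -X POST http://localhost:8000/generate/stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain what a knowledge graph is.", "max_new_tokens": 256}'
#
#   # Inspect the background knowledge graph
#   curl http://localhost:8000/knowledge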