from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import pipeline, TextStreamer
import torch
import re
import threading
import queue
import time
import asyncio

from duckduckgo_search import DDGS

# ------------------------
# Config
# ------------------------
MAIN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
QUERY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
SUMMARY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
DEVICE = 0 if torch.cuda.is_available() else "cpu"
DEEPSEEK_MAX_TOKENS = 64000
SMOLLM_MAX_TOKENS = 4192
KG_UPDATE_INTERVAL = 60  # seconds

# Maps search query -> summarized search results; filled in by the background updater.
knowledge_graph = {}

# ------------------------
# API + Models Init
# ------------------------
app = FastAPI()

print("[Init] Loading models...")
generator = pipeline("text-generation", model=MAIN_MODEL, device=DEVICE)
query_generator = pipeline("text-generation", model=QUERY_MODEL, device=DEVICE)
summarizer = pipeline("text-generation", model=SUMMARY_MODEL, device=DEVICE)
print("[Init] Models loaded.")


class ModelInput(BaseModel):
    prompt: str
    max_new_tokens: int = DEEPSEEK_MAX_TOKENS


# ------------------------
# KG Functions
# ------------------------
def generate_dynamic_query():
    prompt = (
        "Generate a short, specific search query about technology, startups, AI, or science. "
        "Be creative, realistic, and output only the query with no extra words."
    )
    # return_full_text=False keeps the instruction prompt itself out of the output.
    output = query_generator(
        prompt,
        max_new_tokens=SMOLLM_MAX_TOKENS,
        truncation=True,
        do_sample=True,
        temperature=1.0,
        top_p=0.9,
        return_full_text=False,
    )[0]["generated_text"].strip()
    # Keep only the first line and strip any echoed instruction prefix, just in case.
    query = output.split("\n")[0]
    query = re.sub(r"^Generate.*?:", "", query).strip()
    return query


def summarize_text(text):
    summary_prompt = f"Summarize this in 3 concise sentences:\n\n{text}"
    return summarizer(
        summary_prompt,
        max_new_tokens=SMOLLM_MAX_TOKENS,
        truncation=True,
        return_full_text=False,
    )[0]["generated_text"].strip()


def search_ddg(query):
    with DDGS() as ddgs:
        results = list(ddgs.text(query, max_results=5))
    combined = " ".join(r["body"] for r in results if "body" in r)
    return combined.strip()


def kg_updater():
    """Background loop: invent a query, search DuckDuckGo, store a summary."""
    while True:
        try:
            query = generate_dynamic_query()
            print(f"[KG Updater] Searching DDG for query: {query}")
            raw_text = search_ddg(query)
            if len(raw_text) < 50:
                print("[KG Updater] Too little info found, retrying next cycle...")
            else:
                summary = summarize_text(raw_text)
                knowledge_graph[query] = summary
                print(f"[KG Updater] Knowledge graph updated for query: {query}")
        except Exception as e:
            print(f"[KG Updater ERROR] {e}")
        time.sleep(KG_UPDATE_INTERVAL)


threading.Thread(target=kg_updater, daemon=True).start()


def inject_relevant_kg(prompt):
    """Append stored summaries whose query shares a word with the prompt."""
    relevant_info = ""
    for k, v in knowledge_graph.items():
        if any(word.lower() in prompt.lower() for word in k.split()):
            relevant_info += f"\n[KG:{k}] {v}"
    if relevant_info:
        return f"{prompt}\n\nRelevant background info:\n{relevant_info}"
    return prompt
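# Quick illustration of the injection above (hypothetical values, not produced
# by the running service): if the updater has stored
#     knowledge_graph = {"AI startups funding": "Funding rounds slowed in ..."}
# then any prompt containing "AI", "startups", or "funding" as a substring gets
# the summary appended under a "Relevant background info:" header, e.g.
#     inject_relevant_kg("Which AI startups raised money?")
# returns the prompt followed by "[KG:AI startups funding] Funding rounds slowed in ..."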
# ------------------------
# Streaming Generation
# ------------------------
class QueueStreamer(TextStreamer):
    """TextStreamer that enqueues decoded text instead of printing it.

    Overriding on_finalized_text (rather than monkey-patching put) keeps
    TextStreamer's skip_prompt and incremental-decoding logic intact.
    """

    def __init__(self, tokenizer, text_queue, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=True, **decode_kwargs)
        self.text_queue = text_queue

    def on_finalized_text(self, text, stream_end=False):
        if text:
            self.text_queue.put(text)


@app.post("/generate/stream")
async def generate_stream(input: ModelInput):
    q = queue.Queue()

    def run_generation():
        try:
            streamer = QueueStreamer(generator.tokenizer, q, skip_special_tokens=True)
            enriched_prompt = inject_relevant_kg(input.prompt)
            generator(
                enriched_prompt,
                max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
                do_sample=False,
                streamer=streamer,
            )
        except Exception as e:
            q.put(f"[ERROR] {e}")
        finally:
            q.put(None)  # sentinel: generation finished

    threading.Thread(target=run_generation, daemon=True).start()

    async def event_generator():
        loop = asyncio.get_running_loop()
        while True:
            # q.get() blocks, so run it in a worker thread to keep the event loop responsive.
            token = await loop.run_in_executor(None, q.get)
            if token is None:
                break
            yield token

    return StreamingResponse(event_generator(), media_type="text/plain")


# ------------------------
# Non-stream endpoint
# ------------------------
@app.post("/generate")
def generate_text(input: ModelInput):
    # Plain `def` so FastAPI runs this blocking model call in its threadpool
    # instead of stalling the event loop.
    try:
        enriched_prompt = inject_relevant_kg(input.prompt)
        output = generator(
            enriched_prompt,
            max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
            do_sample=False,
        )[0]["generated_text"]
        return {"generated_text": output}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# ------------------------
# KG endpoint
# ------------------------
@app.get("/knowledge")
async def get_knowledge():
    return knowledge_graph


# ------------------------
# Root endpoint
# ------------------------
@app.get("/")
async def root():
    return {"message": "Welcome to the Streaming Model API with KG Updater!"}
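# ------------------------
# Example usage (sketch)
# ------------------------
# A minimal way to run and exercise the API, assuming this file is saved as
# main.py (the filename is an assumption, not fixed by the code above):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   # Stream tokens as they are generated (-N disables curl's buffering):
#   curl -N -X POST http://localhost:8000/generate/stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain retrieval-augmented generation.", "max_new_tokens": 256}'
#
#   # One-shot generation:
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_new_tokens": 64}'
#
#   # Inspect what the background updater has collected so far:
#   curl http://localhost:8000/knowledge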