import asyncio
import re
import threading
import time

import torch
from duckduckgo_search import DDGS
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import pipeline, TextIteratorStreamer
# ------------------------
# Config
# ------------------------
MAIN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
QUERY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
SUMMARY_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"  # pipeline() accepts device strings
DEEPSEEK_MAX_TOKENS = 64000
SMOLLM_MAX_TOKENS = 4192
KG_UPDATE_INTERVAL = 60 # seconds
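# In-memory knowledge store: maps each generated search query to a short
# summary of its DuckDuckGo results. It grows without bound; nothing evicts
# old entries.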
knowledge_graph = {}
# ------------------------
# API + Models Init
# ------------------------
app = FastAPI()
print("[Init] Loading models...")
generator = pipeline("text-generation", model=MAIN_MODEL, device=DEVICE)
query_generator = pipeline("text-generation", model=QUERY_MODEL, device=DEVICE)
summarizer = pipeline("text-generation", model=SUMMARY_MODEL, device=DEVICE)
print("[Init] Models loaded.")
class ModelInput(BaseModel):
prompt: str
max_new_tokens: int = DEEPSEEK_MAX_TOKENS
# ------------------------
# KG Functions
# ------------------------
def generate_dynamic_query():
    prompt = (
        "Generate a short, specific search query about technology, startups, AI, or science. "
        "Be creative, realistic, and output only the query with no extra words."
    )
    # return_full_text=False makes the pipeline return only the completion.
    # Without it, "generated_text" starts with the prompt itself, so taking
    # the first line below would yield the instruction rather than a query.
    output = query_generator(
        prompt,
        max_new_tokens=64,  # a one-line query never needs thousands of tokens
        truncation=True,
        do_sample=True,
        temperature=1.0,
        top_p=0.9,
        return_full_text=False,
    )[0]["generated_text"].strip()
    # Keep the first line and strip any "Query:"-style label the model adds.
    query = output.split("\n")[0].strip().strip('"')
    query = re.sub(r"^(?:search\s+)?query\s*:\s*", "", query, flags=re.IGNORECASE).strip()
    return query
def summarize_text(text):
    summary_prompt = f"Summarize this in 3 concise sentences:\n\n{text}"
    # return_full_text=False so the result is only the summary rather than
    # the prompt (which would echo the entire scraped text back).
    return summarizer(
        summary_prompt,
        max_new_tokens=SMOLLM_MAX_TOKENS,
        truncation=True,
        return_full_text=False,
    )[0]["generated_text"].strip()
def search_ddg(query):
with DDGS() as ddgs:
results = list(ddgs.text(query, max_results=5))
combined = " ".join(r["body"] for r in results if "body" in r)
return combined.strip()
def kg_updater():
while True:
try:
query = generate_dynamic_query()
print(f"[KG Updater] Searching DDG for query: {query}")
raw_text = search_ddg(query)
if len(raw_text) < 50:
print("[KG Updater] Too little info found, retrying next cycle...")
else:
summary = summarize_text(raw_text)
knowledge_graph[query] = summary
print(f"[KG Updater] Knowledge graph updated for query: {query}")
except Exception as e:
print(f"[KG Updater ERROR] {e}")
time.sleep(KG_UPDATE_INTERVAL)
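# Background refresher: daemon=True so the thread never blocks interpreter
# shutdown. It starts at import time, after the models above have loaded.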
threading.Thread(target=kg_updater, daemon=True).start()
def inject_relevant_kg(prompt):
    # Match on whole words longer than three characters, so stopwords and
    # accidental substrings ("AI" inside "maintain") don't mark every KG
    # entry as relevant and bloat the prompt.
    prompt_words = set(re.findall(r"\w+", prompt.lower()))
    relevant_info = ""
    for k, v in knowledge_graph.items():
        key_words = {w for w in re.findall(r"\w+", k.lower()) if len(w) > 3}
        if key_words & prompt_words:
            relevant_info += f"\n[KG:{k}] {v}"
    if relevant_info:
        return f"{prompt}\n\nRelevant background info:\n{relevant_info}"
    return prompt
# ------------------------
# Streaming Generation
# ------------------------
@app.post("/generate/stream")
async def generate_stream(input: ModelInput):
q = queue.Queue()
def run_generation():
try:
tokenizer = generator.tokenizer
def enqueue_token(token_ids):
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
text = tokenizer.decode(token_ids, skip_special_tokens=True)
q.put(text)
streamer = TextStreamer(tokenizer, skip_prompt=True)
streamer.put = enqueue_token # intercept tokens
enriched_prompt = inject_relevant_kg(input.prompt)
generator(
enriched_prompt,
max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
do_sample=False,
streamer=streamer
)
except Exception as e:
q.put(f"[ERROR] {e}")
finally:
q.put(None)
threading.Thread(target=run_generation, daemon=True).start()
async def event_generator():
while True:
token = q.get()
if token is None:
break
yield token
return StreamingResponse(event_generator(), media_type="text/plain")
# ------------------------
# Non-stream endpoint
# ------------------------
@app.post("/generate")
async def generate_text(input: ModelInput):
try:
enriched_prompt = inject_relevant_kg(input.prompt)
output = generator(
enriched_prompt,
max_new_tokens=min(input.max_new_tokens, DEEPSEEK_MAX_TOKENS),
do_sample=False
)[0]["generated_text"]
return {"generated_text": output}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ------------------------
# KG endpoint
# ------------------------
@app.get("/knowledge")
async def get_knowledge():
return knowledge_graph
# ------------------------
# Root endpoint
# ------------------------
@app.get("/")
async def root():
return {"message": "Welcome to the Streaming Model API with KG Updater!"}