import functools

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from .llm_utils import simulate_search
from .umls_linker import link_umls

# ZeroSearch search-simulation model, also used here for answer generation.
ANSWER_MODEL = "sunhaonlp/SearchSimulation_14B"


@functools.lru_cache(maxsize=1)
def _load_answer_pipe():
    """Load the generation pipeline once and cache it for the process."""
    tokenizer = AutoTokenizer.from_pretrained(ANSWER_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        ANSWER_MODEL, trust_remote_code=True, device_map="auto"
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding; temperature is ignored in this mode
    )


class Query(BaseModel):
    question: str


app = FastAPI(
    title="ZeroSearch Medical Q&A API",
    description=(
        "Ask clinical questions; get answers with UMLS links, "
        "no external search APIs."
    ),
    version="0.1.0",
)


@app.post("/ask")
def ask(query: Query):
    # Retrieve k simulated documents instead of calling a real search API.
    docs = simulate_search(query.question, k=5)
    context = "\n\n".join(docs)
    prompt = (
        "Answer the medical question strictly based on the provided context.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query.question}\nAnswer:"
    )
    answer_pipe = _load_answer_pipe()  # first call downloads/loads the model (slow)
    # The pipeline echoes the prompt in generated_text, so keep only the
    # completion after the final "Answer:" marker.
    answer = (
        answer_pipe(prompt, num_return_sequences=1)[0]["generated_text"]
        .split("Answer:")[-1]
        .strip()
    )
    # Link medical concepts mentioned in the answer to UMLS identifiers.
    umls = link_umls(answer)
    return {"answer": answer, "docs": docs, "umls": umls}
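
# ---------------------------------------------------------------------------
# Hypothetical sketch (an assumption, NOT part of this module): one plausible
# shape for `.llm_utils.simulate_search`. In the ZeroSearch recipe, the
# simulation LLM writes k pseudo-documents for a query instead of querying an
# external search engine. The prompt wording and function body below are
# illustrative guesses, not the real implementation in .llm_utils.
#
#   def simulate_search(question: str, k: int = 5) -> list[str]:
#       pipe = _load_answer_pipe()
#       prompt = (
#           f"Write {k} short, factual passages a medical search engine "
#           f"might return for the query: {question}\n"
#       )
#       text = pipe(prompt, num_return_sequences=1)[0]["generated_text"]
#       # Drop the echoed prompt, then split the completion into
#       # paragraph-sized pseudo-documents.
#       body = text[len(prompt):]
#       return [p.strip() for p in body.split("\n\n") if p.strip()][:k]
# ---------------------------------------------------------------------------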
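
if __name__ == "__main__":
    # Usage sketch, a minimal local smoke test. Assumptions: this module lives
    # at app/main.py (hypothetical path) and is run as `python -m app.main` so
    # the relative imports resolve; TestClient requires httpx; the example
    # question is illustrative only.
    from fastapi.testclient import TestClient

    client = TestClient(app)
    resp = client.post(
        "/ask",
        json={"question": "What drug class does metformin belong to?"},
    )
    print(resp.status_code)
    print(resp.json()["answer"])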