"""ZeroSearch medical Q&A API: answers clinical questions against simulated
search results (no external search APIs) and links UMLS concepts in the answer."""

import functools

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from .llm_utils import simulate_search
from .umls_linker import link_umls

ANSWER_MODEL = "sunhaonlp/SearchSimulation_14B"


@functools.lru_cache(maxsize=1)
def _load_answer_pipe():
    """Load the answer model once and reuse the cached pipeline across requests."""
    tokenizer = AutoTokenizer.from_pretrained(ANSWER_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        ANSWER_MODEL,
        trust_remote_code=True,
        device_map="auto",  # requires `accelerate`; places weights on available devices
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding; `temperature` has no effect without sampling, so it is dropped
    )


class Query(BaseModel):
    question: str


app = FastAPI(
    title="ZeroSearch Medical Q&A API",
    description="Ask clinical questions; get answers with UMLS links, no external search APIs.",
    version="0.1.0",
)
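

# Optional warm-up: load the model at startup so the first /ask request does not
# block on model initialization. A minimal sketch using FastAPI's on_event hook
# (newer FastAPI versions prefer lifespan handlers); the service works unchanged
# without it, since _load_answer_pipe() is lazily cached anyway.
@app.on_event("startup")
def _warm_model() -> None:
    _load_answer_pipe()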


@app.post("/ask")
def ask(query: Query):
    # Retrieve simulated search results instead of calling an external search API.
    docs = simulate_search(query.question, k=5)
    context = "\n\n".join(docs)
    prompt = (
        "Answer the medical question strictly based on the provided context.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query.question}\nAnswer:"
    )
    answer_pipe = _load_answer_pipe()
    # The pipeline echoes the prompt, so keep only the text after the final "Answer:".
    answer = (
        answer_pipe(prompt, num_return_sequences=1)[0]["generated_text"]
        .split("Answer:")[-1]
        .strip()
    )
    # Link UMLS concepts mentioned in the generated answer.
    umls = link_umls(answer)
    return {"answer": answer, "docs": docs, "umls": umls}
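

# Example request, assuming the server runs locally on port 8000 (the question
# below is illustrative, not from the original code):
#
#   curl -X POST http://localhost:8000/ask \
#     -H "Content-Type: application/json" \
#     -d '{"question": "What is the first-line treatment for type 2 diabetes?"}'
#
# Minimal dev entry point; a sketch that assumes `uvicorn` is installed. Because
# this module uses relative imports, run it as a package module
# (e.g. `python -m <package>.<module>`) rather than as a plain script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)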