# MedQA / backend / llm_utils.py
"""Utilities for loading the ZeroSearch simulation model and performing simulated searches."""
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import functools
MODEL_NAME = "sunhaonlp/SearchSimulation_14B"
@functools.lru_cache(maxsize=1)
def _load_search_pipe():
    """Build and memoize the text-generation pipeline for the simulation model.

    The `lru_cache(maxsize=1)` ensures the (large) model is downloaded and
    loaded at most once per process; subsequent calls return the same
    pipeline object.

    Returns:
        A transformers `text-generation` pipeline configured for greedy
        decoding with up to 512 new tokens.
    """
    # Pass trust_remote_code to the tokenizer as well as the model so a
    # repo that ships custom tokenizer code also loads correctly.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        device_map="auto",  # requires `accelerate`; shards layers across available devices
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # Greedy decoding. `temperature=0.0` was removed: it is ignored when
        # do_sample=False, and recent transformers versions warn/reject a
        # zero temperature in the generation config.
        do_sample=False,
    )
def simulate_search(query: str, k: int = 5):
    """Generate up to *k* synthetic documents for *query*.

    Args:
        query: The search query to simulate results for.
        k: Maximum number of documents to return (default 5).

    Returns:
        A list of at most *k* non-empty document strings.
    """
    pipe = _load_search_pipe()
    prompt = f"SearchSimulation:\nQuery: {query}\nDocuments:"
    # BUG FIX: the pipeline uses greedy decoding (do_sample=False), and
    # transformers raises ValueError when num_return_sequences > 1 under
    # greedy search — so `pipe(prompt, num_return_sequences=k)` crashed for
    # any k > 1 (and would have produced k identical strings otherwise).
    # Generate a single continuation and split it into documents instead.
    outputs = pipe(prompt)
    text = outputs[0]["generated_text"]
    # generated_text echoes the prompt; keep only what follows "Documents:".
    body = text.split("Documents:")[-1].strip()
    # Assumes the simulator separates documents with blank lines — TODO
    # confirm against actual model output format.
    docs = [doc.strip() for doc in body.split("\n\n") if doc.strip()]
    return docs[:k]