import time

import faiss
import gradio as gr
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# -- Load data & embedding model --
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
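# For reference, a minimal sketch of how the corpus and index loaded above could
# have been built. This is an assumption, not the app's actual preprocessing: it
# presumes retrieval_corpus.csv has a "text" column and that the same
# all-MiniLM-L6-v2 embedder was used. It is not called at runtime.
def build_faiss_index(csv_path="retrieval_corpus.csv", out_path="faiss_index.bin"):
    corpus = pd.read_csv(csv_path)
    vecs = embedding_model.encode(corpus["text"].tolist()).astype("float32")
    idx = faiss.IndexFlatL2(vecs.shape[1])  # exact L2 search over the embedding vectors
    idx.add(vecs)
    faiss.write_index(idx, out_path)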

# -- Quantized BioMedLM with CPU offload --
model_id = "stanford-crfm/BioMedLM"
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_enable_fp32_cpu_offload=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # BioMedLM has no pad token; reuse EOS for padding
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": "cpu"},
)
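# Note: bitsandbytes int8 kernels run on CUDA only. With device_map={"": "cpu"} and
# llm_int8_enable_fp32_cpu_offload=True, the CPU-placed weights are kept in fp32, so
# this setup trades memory savings for the ability to run without a GPU.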

def retrieve_top_k(q, k=5):
    """Return the k nearest corpus rows for query q, with their L2 distances."""
    emb = embedding_model.encode([q]).astype("float32")
    distances, indices = index.search(emb, k)
    results = df.iloc[indices[0]].copy()
    results["score"] = distances[0]
    return results

def build_prompt(q, docs):
    ctx = "\n".join(f"- {d['text']}" for _,d in docs.iterrows())
    return f"""[INST] <<SYS>>…[/INST]"""  # your existing template

def generate_local_answer(prompt, max_new_tokens=512):
    """Greedy-decode on CPU and return the full decoded text (prompt included)."""
    device = torch.device("cpu")
    start = time.time()
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
    out = generation_model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
    )
    print(f"Gen time: {time.time()-start:.2f}s")
    return tokenizer.decode(out[0], skip_special_tokens=True)
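# Optional: BioMedLM is a causal LM, so the decoded output begins with the prompt
# itself. A small helper (an assumption, not wired into the interface below) to
# keep only the newly generated text:
def strip_prompt(prompt, decoded):
    return decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded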

iface = gr.Interface(
    fn=lambda q: generate_local_answer(build_prompt(q, retrieve_top_k(q))),
    inputs="text",
    outputs="text",
)
iface.launch()