Spaces:
Sleeping
Sleeping
Commit
·
a73c563
1
Parent(s):
876d145
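Updated generate_local_answer: use greedy decoding (do_sample=False, num_beams=1) in place of the top_k/top_p sampling arguments, and print the generation time.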
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import faiss
|
|
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
@@ -54,6 +55,8 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
|
|
54 |
def generate_local_answer(prompt, max_new_tokens=512):
|
55 |
device = torch.device("cpu")
|
56 |
print(f"Using device: {device}")
|
|
|
|
|
57 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
|
58 |
input_ids = inputs["input_ids"].to(device)
|
59 |
attention_mask = inputs["attention_mask"].to(device)
|
@@ -62,12 +65,12 @@ def generate_local_answer(prompt, max_new_tokens=512):
|
|
62 |
input_ids=input_ids,
|
63 |
attention_mask=attention_mask,
|
64 |
max_new_tokens=max_new_tokens,
|
65 |
-
|
66 |
-
|
67 |
-
top_k=50,
|
68 |
-
top_p=0.95,
|
69 |
)
|
|
|
70 |
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
71 |
return decoded.split("### Diagnostic Explanation:")[-1].strip()
|
72 |
|
73 |
def rag_chat(query):
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import faiss
|
4 |
+
import time
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
|
|
55 |
def generate_local_answer(prompt, max_new_tokens=512):
|
56 |
device = torch.device("cpu")
|
57 |
print(f"Using device: {device}")
|
58 |
+
start = time.time()
|
59 |
+
|
60 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
|
61 |
input_ids = inputs["input_ids"].to(device)
|
62 |
attention_mask = inputs["attention_mask"].to(device)
|
|
|
65 |
input_ids=input_ids,
|
66 |
attention_mask=attention_mask,
|
67 |
max_new_tokens=max_new_tokens,
|
68 |
+
do_sample=False, # ← GREEDY
|
69 |
+
num_beams=1,
|
|
|
|
|
70 |
)
|
71 |
+
|
72 |
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
|
73 |
+
print(f"Time taken: {time.time() - start:.2f}s")
|
74 |
return decoded.split("### Diagnostic Explanation:")[-1].strip()
|
75 |
|
76 |
def rag_chat(query):
|