asadsandhu committed on
Commit
a73c563
·
1 Parent(s): 876d145
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import faiss
 
4
  import numpy as np
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
@@ -54,6 +55,8 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
54
  def generate_local_answer(prompt, max_new_tokens=512):
55
  device = torch.device("cpu")
56
  print(f"Using device: {device}")
 
 
57
  inputs = tokenizer(prompt, return_tensors="pt", padding=True)
58
  input_ids = inputs["input_ids"].to(device)
59
  attention_mask = inputs["attention_mask"].to(device)
@@ -62,12 +65,12 @@ def generate_local_answer(prompt, max_new_tokens=512):
62
  input_ids=input_ids,
63
  attention_mask=attention_mask,
64
  max_new_tokens=max_new_tokens,
65
- temperature=0.5,
66
- do_sample=True,
67
- top_k=50,
68
- top_p=0.95,
69
  )
 
70
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
 
71
  return decoded.split("### Diagnostic Explanation:")[-1].strip()
72
 
73
  def rag_chat(query):
 
1
  import gradio as gr
2
  import pandas as pd
3
  import faiss
4
+ import time
5
  import numpy as np
6
  import torch
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
55
  def generate_local_answer(prompt, max_new_tokens=512):
56
  device = torch.device("cpu")
57
  print(f"Using device: {device}")
58
+ start = time.time()
59
+
60
  inputs = tokenizer(prompt, return_tensors="pt", padding=True)
61
  input_ids = inputs["input_ids"].to(device)
62
  attention_mask = inputs["attention_mask"].to(device)
 
65
  input_ids=input_ids,
66
  attention_mask=attention_mask,
67
  max_new_tokens=max_new_tokens,
68
+ do_sample=False, # ← GREEDY
69
+ num_beams=1,
 
 
70
  )
71
+
72
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
73
+ print(f"Time taken: {time.time() - start:.2f}s")
74
  return decoded.split("### Diagnostic Explanation:")[-1].strip()
75
 
76
  def rag_chat(query):