asadsandhu committed
Commit 0619f86 · 1 Parent(s): ff04da1
Files changed (1)
  1. app.py +8 -5
app.py CHANGED
@@ -30,12 +30,14 @@ bnb_config = BitsAndBytesConfig(
 )
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 generation_model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    quantization_config=bnb_config
-)
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    device_map="auto" if torch.cuda.is_available() else None,
+    quantization_config=bnb_config if torch.cuda.is_available() else None
+).to(device)
 
 # ----------------------
 # RAG Functions
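Read back as one block, the new loading path is sketched below. This is a minimal reconstruction, not the full app.py: the model_id string and the bnb_config contents are placeholders standing in for values defined earlier in the file. One hedged deviation from the diff: the sketch calls .to(device) only on the CPU path, since .to() on a bitsandbytes-quantized or device_map-placed model is redundant on GPU and can raise an error in recent transformers releases.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "some/causal-lm-checkpoint"              # placeholder; app.py defines its own
bnb_config = BitsAndBytesConfig(load_in_4bit=True)  # placeholder; built earlier in app.py

has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)

generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if has_cuda else torch.float32,  # half precision only on GPU
    device_map="auto" if has_cuda else None,                   # let accelerate place shards on GPU
    quantization_config=bnb_config if has_cuda else None,      # bitsandbytes requires CUDA
)
if not has_cuda:
    # Explicit move only on the CPU path; a quantized or device-mapped
    # model is already placed, and .to() on it can raise.
    generation_model = generation_model.to(device)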
@@ -69,7 +71,8 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
     return prompt
 
 def generate_local_answer(prompt, max_new_tokens=512):
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     output = generation_model.generate(
         input_ids=input_ids,
         max_new_tokens=max_new_tokens,
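A matching sketch of the device-safe generation helper follows; the decoding and return step is an assumption, since the hunk is truncated at max_new_tokens.

def generate_local_answer(prompt, max_new_tokens=512):
    # Resolve the device at call time, mirroring the commit, so the helper
    # works whether or not a GPU was visible when the module loaded.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output = generation_model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
    )
    # Assumed decoding step: the diff cuts off before the function returns.
    return tokenizer.decode(output[0], skip_special_tokens=True)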