asadsandhu commited on
Commit
aeaead2
·
1 Parent(s): e31fef3
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -69,7 +69,9 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
69
  return prompt
70
 
71
  def generate_local_answer(prompt, max_new_tokens=512):
72
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
 
 
73
  output = generation_model.generate(
74
  input_ids=input_ids,
75
  max_new_tokens=max_new_tokens,
 
69
  return prompt
70
 
71
  def generate_local_answer(prompt, max_new_tokens=512):
72
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
+ print(f"Using device: {device}")
74
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
75
  output = generation_model.generate(
76
  input_ids=input_ids,
77
  max_new_tokens=max_new_tokens,