justinj92 commited on
Commit
d2f3905
·
verified ·
1 Parent(s): 95f5955

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -89,7 +89,7 @@ vectordb = FAISS.load_local(CFG.Output_folder + '/faiss_index_ml_papers', embedd
89
  @spaces.GPU
90
  def build_model(model_repo = CFG.model_name):
91
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
92
- model = AutoModelForCausalLM.from_pretrained(model_repo, attn_implementation="flash_attention_2")
93
 
94
  return tokenizer, model
95
 
 
89
  @spaces.GPU
90
  def build_model(model_repo = CFG.model_name):
91
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
92
+ model = AutoModelForCausalLM.from_pretrained(model_repo, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
93
 
94
  return tokenizer, model
95