Update app.py
Browse files
app.py
CHANGED
@@ -89,7 +89,7 @@ vectordb = FAISS.load_local(CFG.Output_folder + '/faiss_index_ml_papers', embedd
|
|
89 |
@spaces.GPU
|
90 |
def build_model(model_repo = CFG.model_name):
|
91 |
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
92 |
-
model = AutoModelForCausalLM.from_pretrained(model_repo, attn_implementation="flash_attention_2")
|
93 |
|
94 |
return tokenizer, model
|
95 |
|
|
|
89 |
@spaces.GPU
|
90 |
def build_model(model_repo = CFG.model_name):
|
91 |
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
92 |
+
model = AutoModelForCausalLM.from_pretrained(model_repo, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
|
93 |
|
94 |
return tokenizer, model
|
95 |
|