Update app.py
app.py CHANGED
@@ -35,11 +35,12 @@ tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 # Check CUDA and enable Flash Attention if supported
 use_flash_attention = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+attn_implementation = "flash_attention_2" if use_flash_attention else "eager"  # Default to eager if no compatible GPU
 model = LlamaForCausalLM.from_pretrained(
     MODEL_ID,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-
+    attn_implementation=attn_implementation,
     load_in_8bit=True
 )
 
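For context, a minimal sketch of the load path as it reads after this change. Only the lines shown in the diff come from app.py; the imports, the MODEL_ID value, and the surrounding setup are assumptions for illustration. Using "flash_attention_2" additionally requires the flash-attn package, and load_in_8bit requires bitsandbytes.

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# Assumed checkpoint for illustration; the real MODEL_ID is defined earlier in app.py
MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Flash Attention 2 needs an Ampere-or-newer GPU (compute capability >= 8)
use_flash_attention = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
attn_implementation = "flash_attention_2" if use_flash_attention else "eager"  # Default to eager if no compatible GPU

model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation=attn_implementation,
    load_in_8bit=True
)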