Cylanoid committed on
Commit c417ed1 · verified · 1 parent: 2e7ec4a

Update app.py

Files changed (1): app.py +2 -1
app.py CHANGED
@@ -35,11 +35,12 @@ tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 # Check CUDA and enable Flash Attention if supported
 use_flash_attention = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+attn_implementation = "flash_attention_2" if use_flash_attention else "eager"  # Default to eager if no compatible GPU
 model = LlamaForCausalLM.from_pretrained(
     MODEL_ID,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-    use_flash_attention_2=use_flash_attention,
+    attn_implementation=attn_implementation,
     load_in_8bit=True
 )
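For context, here is a minimal sketch of how the updated loading block reads after this change. The MODEL_ID value below is a hypothetical stand-in (the diff does not show the actual checkpoint name). The attn_implementation argument replaces the deprecated use_flash_attention_2 flag in transformers and falls back to the stock "eager" attention when no Ampere-class (compute capability 8.0+) GPU is available:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

MODEL_ID = "meta-llama/Llama-2-7b-hf"  # hypothetical checkpoint, for illustration only

tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# FlashAttention-2 requires an NVIDIA GPU with compute capability >= 8 (Ampere or newer).
use_flash_attention = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
attn_implementation = "flash_attention_2" if use_flash_attention else "eager"

model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation=attn_implementation,  # replaces the deprecated use_flash_attention_2=True
    load_in_8bit=True,                        # 8-bit quantization via bitsandbytes
)

Passing "eager" explicitly keeps the app loadable on hardware without FlashAttention-2 support, where requesting flash_attention_2 would otherwise fail at load time.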