Quaz1 commited on
Commit
af14831
·
1 Parent(s): 171524c

added install flash attn

Browse files
Files changed (2) hide show
  1. app.py +42 -0
  2. setup.sh +8 -0
app.py CHANGED
@@ -3,6 +3,48 @@ import torch
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import scipy.io.wavfile
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Load model and processor
8
  @gr.cache()
 
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import scipy.io.wavfile
5
  import numpy as np
6
+ import subprocess
7
+ import sys
8
+ import os
9
+
10
+ def setup_flash_attention():
11
+ """One-time setup for flash-attention with special flags"""
12
+ # Check if flash-attn is already installed
13
+ try:
14
+ import flash_attn
15
+ print("flash-attn already installed")
16
+ return
17
+ except ImportError:
18
+ pass
19
+
20
+ # Check if we've already tried to install it in this session
21
+ if os.path.exists("/tmp/flash_attn_installed"):
22
+ return
23
+
24
+ try:
25
+ print("Installing flash-attn with --no-build-isolation...")
26
+ subprocess.run([
27
+ sys.executable, "-m", "pip", "install",
28
+ "flash-attn==2.7.3", "--no-build-isolation"
29
+ ], check=True)
30
+
31
+ # Uninstall apex if it exists
32
+ subprocess.run([
33
+ sys.executable, "-m", "pip", "uninstall", "apex", "-y"
34
+ ], check=False) # Don't fail if apex isn't installed
35
+
36
+ # Mark as installed
37
+ with open("/tmp/flash_attn_installed", "w") as f:
38
+ f.write("installed")
39
+
40
+ print("flash-attn installation completed")
41
+
42
+ except subprocess.CalledProcessError as e:
43
+ print(f"Warning: Failed to install flash-attn: {e}")
44
+ # Continue anyway - the model might work without it
45
+
46
+ # Run setup once when the module is imported
47
+ setup_flash_attention()
48
 
49
  # Load model and processor
50
  @gr.cache()
setup.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ # Install flash-attn with no-build-isolation
5
+ pip install flash-attn==2.7.3 --no-build-isolation
6
+
7
+ # Uninstall apex if it exists (as in your original script)
8
+ pip uninstall apex -y || true