alexnasa commited on
Commit
6d76a3a
·
verified ·
1 Parent(s): 188aad8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -1
app.py CHANGED
@@ -4,9 +4,29 @@ import gc
4
  import os
5
  import subprocess
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def sh(cmd): subprocess.check_call(cmd, shell=True)
8
 
9
- sh("pip install torch-scatter -f https://data.pyg.org/whl/torch-2.7.0+1.2.6.html")
10
 
11
 
12
  import shutil
 
4
  import os
5
  import subprocess
6
 
7
+ def install_cuda_toolkit():
8
+ CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
9
+ CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
10
+ subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
11
+ subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
12
+ subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
13
+
14
+ os.environ["CUDA_HOME"] = "/usr/local/cuda"
15
+ os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
16
+ os.environ["CPATH"] = "%s/include:%s" % (os.environ["CUDA_HOME"], os.environ["CPATH"])
17
+ os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
18
+ os.environ["CUDA_HOME"],
19
+ "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
20
+ )
21
+ # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
22
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
23
+ print("==> finished installation")
24
+
25
+ install_cuda_toolkit()
26
+
27
  def sh(cmd): subprocess.check_call(cmd, shell=True)
28
 
29
+ sh("pip install torch-scatter")
30
 
31
 
32
  import shutil