Stylique commited on
Commit
ae1a175
·
verified ·
1 Parent(s): e411375

Upload 2 files

Browse files
Files changed (1) hide show
  1. post_install.py +45 -8
post_install.py CHANGED
@@ -325,18 +325,23 @@ def install_pytorch3d():
325
 
326
  # Try the official PyTorch3D installation command for current CUDA version
327
  if PYTORCH_VERSION and CUDA_VERSION:
328
- # Try CUDA 11.7 wheels first (compatible with PyTorch 2.0.1)
329
- if CUDA_VERSION == "cu117":
330
- print("Trying PyTorch3D installation for CUDA 11.7...")
331
- if run_command("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu117_pyt201/download.html"):
332
- print("Successfully installed PyTorch3D for CUDA 11.7")
333
- return True
334
- # Try CUDA 12.6 wheels for newer PyTorch versions
335
- elif CUDA_VERSION == "cu126":
336
  print("Trying PyTorch3D installation for CUDA 12.6...")
337
  if run_command("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu126_pyt271/download.html"):
338
  print("Successfully installed PyTorch3D for CUDA 12.6")
339
  return True
 
 
 
 
 
 
 
 
 
 
 
340
 
341
  # Fallback to CUDA 11.7 wheels
342
  print("Trying official PyTorch3D installation (CUDA 11.7)...")
@@ -411,6 +416,38 @@ def install_pytorch3d():
411
 
412
  return False
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  def install_pytorch_dependencies():
415
  """Install PyTorch-related dependencies"""
416
  print("Installing PyTorch dependencies...")
 
325
 
326
  # Try the official PyTorch3D installation command for current CUDA version
327
  if PYTORCH_VERSION and CUDA_VERSION:
328
+ # Try CUDA 12.6 wheels first (current system CUDA)
329
+ if CUDA_VERSION == "cu126":
 
 
 
 
 
 
330
  print("Trying PyTorch3D installation for CUDA 12.6...")
331
  if run_command("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu126_pyt271/download.html"):
332
  print("Successfully installed PyTorch3D for CUDA 12.6")
333
  return True
334
+ # Try CUDA 11.7 wheels for older PyTorch versions
335
+ elif CUDA_VERSION == "cu117":
336
+ print("Trying PyTorch3D installation for CUDA 11.7...")
337
+ if run_command("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu117_pyt201/download.html"):
338
+ print("Successfully installed PyTorch3D for CUDA 11.7")
339
+ return True
340
+
341
+ # Try to build PyTorch3D from source with system CUDA
342
+ print("Trying PyTorch3D source installation with system CUDA...")
343
+ if install_pytorch3d_from_source():
344
+ return True
345
 
346
  # Fallback to CUDA 11.7 wheels
347
  print("Trying official PyTorch3D installation (CUDA 11.7)...")
 
416
 
417
  return False
418
 
419
+ def install_pytorch3d_from_source():
420
+ """Install PyTorch3D from source with system CUDA"""
421
+ print("Installing PyTorch3D from source with system CUDA...")
422
+
423
+ packages_dir = Path("packages")
424
+
425
+ # Clone PyTorch3D
426
+ if not (packages_dir / "pytorch3d").exists():
427
+ if not run_command("git clone https://github.com/facebookresearch/pytorch3d.git", cwd=packages_dir):
428
+ return False
429
+
430
+ # Install build dependencies
431
+ print("Installing build dependencies...")
432
+ run_command("pip install wheel setuptools ninja")
433
+
434
+ # Install PyTorch3D with system CUDA
435
+ pytorch3d_dir = packages_dir / "pytorch3d"
436
+
437
+ # Set environment variables to use system CUDA
438
+ env = os.environ.copy()
439
+ env['FORCE_CUDA'] = '1'
440
+ env['CUDA_HOME'] = '/usr/local/cuda'
441
+ env['CUDA_VERSION'] = '12.6'
442
+ env['TORCH_CUDA_ARCH_LIST'] = '8.9' # For L4 GPU
443
+
444
+ print("Building PyTorch3D with system CUDA...")
445
+ if run_command("pip install . --no-build-isolation", cwd=pytorch3d_dir, env=env):
446
+ print("Successfully installed PyTorch3D from source with system CUDA")
447
+ return True
448
+
449
+ return False
450
+
451
  def install_pytorch_dependencies():
452
  """Install PyTorch-related dependencies"""
453
  print("Installing PyTorch dependencies...")