thecollabagepatch committed on
Commit
5081ad0
·
1 Parent(s): 079ca7d

ok uv lets see

Browse files
Files changed (2) hide show
  1. Dockerfile +41 -141
  2. pyproject.toml +42 -0
Dockerfile CHANGED
@@ -1,148 +1,48 @@
1
- # thecollabagepatch/magenta:latest
2
- FROM nvidia/cuda:12.6.2-cudnn-runtime-ubuntu22.04
3
 
4
- # CUDA libs present + on loader path
5
  RUN apt-get update && apt-get install -y --no-install-recommends \
6
- cuda-libraries-12-4 && rm -rf /var/lib/apt/lists/*
7
- ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda-12.4/lib64:/usr/local/cuda-12.4/compat:/usr/local/cuda/targets/x86_64-linux/lib:${LD_LIBRARY_PATH}
8
- RUN ln -sf /usr/local/cuda/targets/x86_64-linux/lib /usr/local/cuda/lib64 || true
9
-
10
- # Ensure the NVIDIA repo key is present (non-interactive) and install cuDNN 9.8
11
- RUN set -eux; \
12
- apt-get update && apt-get install -y --no-install-recommends gnupg ca-certificates curl; \
13
- install -d -m 0755 /usr/share/keyrings; \
14
- # Refresh the *same* keyring the base source uses (no second source file)
15
- curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub \
16
- | gpg --batch --yes --dearmor -o /usr/share/keyrings/cuda-archive-keyring.gpg; \
17
- apt-get update; \
18
- # If libcudnn is "held", unhold it so we can move to 9.8
19
- apt-mark unhold libcudnn9-cuda-12 || true; \
20
- # Install cuDNN 9.8 for CUDA 12 (correct dev package name!)
21
- apt-get install -y --no-install-recommends \
22
- 'libcudnn9-cuda-12=9.8.*' \
23
- 'libcudnn9-dev-cuda-12=9.8.*' \
24
- --allow-downgrades --allow-change-held-packages; \
25
- apt-mark hold libcudnn9-cuda-12 || true; \
26
- ldconfig; \
27
- rm -rf /var/lib/apt/lists/*
28
-
29
- # (optional) preload workaround if still needed
30
- ENV LD_PRELOAD=/usr/local/cuda/lib64/libcusparse.so.12:/usr/local/cuda/lib64/libcublas.so.12:/usr/local/cuda/lib64/libcublasLt.so.12:/usr/local/cuda/lib64/libcufft.so.11:/usr/local/cuda/lib64/libcusolver.so.11
31
-
32
- # Better allocator (less fragmentation than BFC during XLA autotune)
33
- ENV TF_GPU_ALLOCATOR=cuda_malloc_async
34
-
35
- # Let cuBLAS use TF32 fast path on Ada (L40S) for big GEMMs
36
- ENV TF_ENABLE_CUBLAS_TF32=1 NVIDIA_TF32_OVERRIDE=1
37
-
38
- ENV DEBIAN_FRONTEND=noninteractive \
39
- PYTHONUNBUFFERED=1 \
40
- PIP_NO_CACHE_DIR=1 \
41
- TF_FORCE_GPU_ALLOW_GROWTH=true \
42
- XLA_PYTHON_CLIENT_PREALLOCATE=false
43
-
44
- ENV JAX_PLATFORMS=""
45
-
46
- # --- OS deps ---
47
- RUN apt-get update && apt-get install -y --no-install-recommends \
48
- software-properties-common curl ca-certificates git \
49
- libsndfile1 ffmpeg \
50
- build-essential pkg-config \
51
- && add-apt-repository ppa:deadsnakes/ppa -y \
52
- && apt-get update && apt-get install -y --no-install-recommends \
53
  python3.11 python3.11-venv python3.11-distutils python3-pip \
54
- && rm -rf /var/lib/apt/lists/*
55
-
56
- # Make python3 => 3.11 for convenience
57
- RUN ln -sf /usr/bin/python3.11 /usr/bin/python && python -m pip install --upgrade pip
58
-
59
- # --- Python deps (pin order matters!) ---
60
- # 1) JAX CUDA pins
61
- RUN python -m pip install "jax[cuda12]==0.6.2" "jaxlib==0.6.2"
62
-
63
- # 2) Lock seqio early to avoid backtracking madness
64
- RUN python -m pip install "seqio==0.0.11"
65
-
66
- # 3) Install Magenta RT *without* deps so we control pins
67
- RUN python -m pip install --no-deps 'git+https://github.com/magenta/magenta-realtime#egg=magenta_rt[gpu]'
68
-
69
- # 4) TF nightlies (MATCH DATES!)
70
- RUN python -m pip install \
71
- "tf_nightly==2.20.0.dev20250619" \
72
- "tensorflow-text-nightly==2.20.0.dev20250316" \
73
- "tf-hub-nightly"
74
-
75
- # 5) tf2jax pinned alongside tf_nightly so pip doesn’t drag stable TF
76
- RUN python -m pip install tf2jax "tf_nightly==2.20.0.dev20250619"
77
-
78
- # 6) The rest of MRT deps + API runtime deps
79
- RUN python -m pip install \
80
- gin-config librosa resampy soundfile \
81
- google-auth google-auth-oauthlib google-auth-httplib2 \
82
- google-api-core googleapis-common-protos google-resumable-media \
83
- google-cloud-storage requests tqdm typing-extensions numpy==2.1.3 \
84
- fastapi uvicorn[standard] python-multipart pyloudnorm
85
-
86
- # 7) Exact commits for T5X/Flaxformer as in pyproject
87
- RUN python -m pip install \
88
- "t5x @ git+https://github.com/google-research/t5x.git@92c5b46" \
89
- "flaxformer @ git+https://github.com/google/flaxformer@399ea3a"
90
-
91
- # ---- FINAL: enforce TF nightlies and clean any stable TF ----
92
- RUN python - <<'PY'
93
- import sys, sysconfig, glob, os, shutil
94
- # Find a writable site dir (site-packages OR dist-packages)
95
- cands = [sysconfig.get_paths().get('purelib'), sysconfig.get_paths().get('platlib')]
96
- cands += [p for p in sys.path if p and p.endswith(('site-packages','dist-packages'))]
97
- site = next(p for p in cands if p and os.path.isdir(p))
98
-
99
- patterns = [
100
- "tensorflow", "tensorflow-*.dist-info", "tensorflow-*.egg-info",
101
- "tf-nightly-*.dist-info", "tf_nightly-*.dist-info",
102
- "tensorflow_text", "tensorflow_text-*.dist-info",
103
- "tf-hub-nightly-*.dist-info", "tf_hub_nightly-*.dist-info",
104
- "tf_keras-nightly-*.dist-info", "tf_keras_nightly-*.dist-info",
105
- "tensorboard*", "tb-nightly-*.dist-info",
106
- "keras*", # remove stray keras
107
- "tensorflow_hub*", "tensorflow_io*",
108
- ]
109
- for pat in patterns:
110
- for path in glob.glob(os.path.join(site, pat)):
111
- if os.path.isdir(path): shutil.rmtree(path, ignore_errors=True)
112
- else:
113
- try: os.remove(path)
114
- except FileNotFoundError: pass
115
-
116
- print("TF/Hub/Text cleared in:", site)
117
  PY
118
 
119
- # Reinstall pinned nightlies in ONE transaction
120
- RUN python -m pip install --no-cache-dir --force-reinstall \
121
- "tf-nightly==2.20.0.dev20250619" \
122
- "tensorflow-text-nightly==2.20.0.dev20250316" \
123
- "tf-hub-nightly"
124
-
125
- RUN python -m pip install huggingface_hub
126
-
127
- RUN python -m pip install --no-cache-dir --force-reinstall "protobuf==4.25.3"
128
-
129
- RUN python -m pip install gradio
130
-
131
-
132
-
133
- # Switch to Spaces’ preferred user
134
- # Switch to Spaces’ preferred user
135
- RUN useradd -m -u 1000 appuser
136
- WORKDIR /home/appuser/app
137
-
138
- # Copy from *build context* into image, owned by appuser
139
- COPY --chown=appuser:appuser app.py /home/appuser/app/app.py
140
-
141
- # NEW: shared utils + worker
142
- COPY --chown=appuser:appuser utils.py /home/appuser/app/utils.py
143
- COPY --chown=appuser:appuser jam_worker.py /home/appuser/app/jam_worker.py
144
-
145
- USER appuser
146
 
147
  EXPOSE 7860
148
- CMD ["bash", "-lc", "python -m uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860}"]
 
1
+ FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
2
+ # ^ Pick one CUDA minor (12.4 or 12.6) and use it everywhere; 12.4 is used here to stay consistent with the library paths set via LD_LIBRARY_PATH.
3
 
4
+ # OS deps
5
  RUN apt-get update && apt-get install -y --no-install-recommends \
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  python3.11 python3.11-venv python3.11-distutils python3-pip \
7
+ libsndfile1 ffmpeg git ca-certificates curl \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # install uv: download the installer to a file first so a failed download fails the build instead of being masked by a pipe; the installer is already non-interactive and does not document a -y option
11
+ RUN curl -LsSf https://astral.sh/uv/install.sh -o /tmp/uv-install.sh && sh /tmp/uv-install.sh && rm -f /tmp/uv-install.sh
12
+ ENV PATH="/root/.local/bin:${PATH}"
13
+
14
+ # CUDA loader path (avoid hard pin to a different minor)
15
+ ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH}
16
+
17
+ # TF/GPU niceties
18
+ ENV TF_FORCE_GPU_ALLOW_GROWTH=true \
19
+ XLA_PYTHON_CLIENT_PREALLOCATE=false \
20
+ JAX_PLATFORMS=cuda,cpu
21
+
22
+ # copy project manifest and lock it deterministically
23
+ WORKDIR /opt/app
24
+ COPY pyproject.toml ./
25
+
26
+ # produce a lock (or check in uv.lock and just COPY it instead)
27
+ RUN uv lock
28
+
29
+ # sync deps into a venv at /opt/venv; uv defaults to ./.venv inside the project dir, so UV_PROJECT_ENVIRONMENT points it at the path the sanity check and CMD use
30
+ RUN UV_PROJECT_ENVIRONMENT=/opt/venv uv sync --frozen --python=/usr/bin/python3.11 --no-dev
31
+
32
+ # show JAX versions (build-time sanity)
33
+ RUN /opt/venv/bin/python - <<'PY'
34
+ import jax, jaxlib
35
+ print("JAX:", jax.__version__)
36
+ print("JAXLIB:", jaxlib.__version__)
37
+ try:
38
+ import importlib
39
+ print("CUDA plugin:", importlib.metadata.version("jax-cuda12-plugin"))
40
+ except Exception as e:
41
+ print("CUDA plugin:", "not found?", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  PY
43
 
44
+ # app files
45
+ COPY app.py utils.py jam_worker.py ./
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  EXPOSE 7860
48
+ # Honour an optional PORT override (default 7860, matching EXPOSE); "exec" replaces the shell so uvicorn runs as PID 1 and receives SIGTERM
+ CMD ["/bin/sh","-c","exec /opt/venv/bin/uvicorn app:app --host 0.0.0.0 --port \"${PORT:-7860}\""]
pyproject.toml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "mrt-app"
3
+ version = "0.0.1"
4
+ requires-python = ">=3.11"
5
+
6
+ dependencies = [
7
+ # core, pinned to avoid drift
8
+ "numpy==2.1.3",
9
+ "seqio==0.0.11",
10
+ "gin-config",
11
+ "librosa",
12
+ "resampy",
13
+ "soundfile",
14
+ "tqdm",
15
+ "typing-extensions",
16
+ "requests",
17
+ "fastapi",
18
+ "uvicorn[standard]",
19
+ "pyloudnorm",
20
+ "protobuf==4.25.3",
21
+
22
+ # tensorflow nightlies (your chosen dates)
23
+ "tf-nightly==2.20.0.dev20250619",
24
+ "tensorflow-text-nightly==2.20.0.dev20250316",
25
+ "tf-hub-nightly",
26
+ "tf2jax",
27
+
28
+ # research libs at fixed commits
29
+ "t5x @ git+https://github.com/google-research/t5x.git@92c5b46",
30
+ "flaxformer @ git+https://github.com/google/flaxformer@399ea3a",
31
+
32
+ # install Magenta RT from git with the GPU extra
33
+ "magenta_rt[gpu] @ git+https://github.com/magenta/magenta-realtime",
34
+ ]
35
+
36
+ [tool.uv] # resolver settings; the uv key for forced versions is "override-dependencies"
37
+ override-dependencies = [
38
+ # keep the JAX triplet in lockstep explicitly
39
+ "jax==0.7.1",
40
+ "jaxlib==0.7.1",
41
+ "jax-cuda12-plugin==0.7.1",
42
+ ]