Commit
·
9a1b4dc
1
Parent(s):
6990691
coords fix not gonna work
Browse files
app.py
CHANGED
@@ -21,6 +21,42 @@ import logging
|
|
21 |
import gradio as gr
|
22 |
from typing import Optional
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def create_documentation_interface():
|
25 |
"""Create a Gradio interface for documentation and transparency"""
|
26 |
|
|
|
21 |
import gradio as gr
|
22 |
from typing import Optional
|
23 |
|
24 |
+
# --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
|
25 |
+
def _patch_t5x_for_gpu_coords():
|
26 |
+
try:
|
27 |
+
import jax
|
28 |
+
from t5x import partitioning as _t5x_part
|
29 |
+
|
30 |
+
old_bounds = getattr(_t5x_part, "bounds_from_last_device", None)
|
31 |
+
old_getcoords = getattr(_t5x_part, "get_coords", None)
|
32 |
+
|
33 |
+
def _bounds_from_last_device_gpu_safe(last_device):
|
34 |
+
# TPU: coords + core_on_chip
|
35 |
+
core = getattr(last_device, "core_on_chip", None)
|
36 |
+
coords = getattr(last_device, "coords", None)
|
37 |
+
if coords is not None and core is not None:
|
38 |
+
x, y, z = coords
|
39 |
+
return x + 1, y + 1, z + 1, core + 1
|
40 |
+
# Non-TPU (or GPU lacking core_on_chip): hosts x local_devices
|
41 |
+
return jax.host_count(), jax.local_device_count()
|
42 |
+
|
43 |
+
def _get_coords_gpu_safe(device):
|
44 |
+
core = getattr(device, "core_on_chip", None)
|
45 |
+
coords = getattr(device, "coords", None)
|
46 |
+
if coords is not None and core is not None:
|
47 |
+
return (*coords, core)
|
48 |
+
# Fallback that works on CPU/GPU
|
49 |
+
return (device.process_index, device.id % jax.local_device_count())
|
50 |
+
|
51 |
+
_t5x_part.bounds_from_last_device = _bounds_from_last_device_gpu_safe
|
52 |
+
_t5x_part.get_coords = _get_coords_gpu_safe
|
53 |
+
import logging; logging.info("Patched t5x.partitioning for GPU coords without core_on_chip.")
|
54 |
+
except Exception as e:
|
55 |
+
import logging; logging.exception("t5x GPU-coords patch failed: %s", e)
|
56 |
+
|
57 |
+
# Call the patch immediately at import time (before MagentaRT init)
|
58 |
+
_patch_t5x_for_gpu_coords()
|
59 |
+
|
60 |
def create_documentation_interface():
|
61 |
"""Create a Gradio interface for documentation and transparency"""
|
62 |
|