thecollabagepatch committed on
Commit
9a1b4dc
·
1 Parent(s): 6990691

coords fix not gonna work

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py CHANGED
@@ -21,6 +21,42 @@ import logging
21
  import gradio as gr
22
  from typing import Optional
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def create_documentation_interface():
25
  """Create a Gradio interface for documentation and transparency"""
26
 
 
21
  import gradio as gr
22
  from typing import Optional
23
 
24
# --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
def _patch_t5x_for_gpu_coords():
    """Monkey-patch ``t5x.partitioning`` so its mesh helpers tolerate GPU devices.

    Replaces ``bounds_from_last_device`` and ``get_coords`` on the
    ``t5x.partitioning`` module with variants that take the TPU code path only
    when a device exposes *both* ``coords`` and ``core_on_chip``.  Any other
    device (CPU, or a GPU that reports ``coords`` but no ``core_on_chip``)
    falls back to a ``(process, local device)`` layout.

    The patch is best effort: if ``jax`` or ``t5x`` cannot be imported the
    failure is logged with its traceback and the function returns normally,
    leaving ``t5x.partitioning`` untouched.

    Returns:
        None.
    """
    import logging

    try:
        import jax
        from t5x import partitioning as _t5x_part
    except Exception as e:
        # Broad on purpose: importing t5x against a mismatched JAX can fail
        # with more than ImportError, and app start-up must not crash here.
        logging.exception("t5x GPU-coords patch failed: %s", e)
        return

    def _bounds_from_last_device_gpu_safe(last_device):
        """Return the mesh bounds implied by the highest-coordinate device."""
        # TPU: coords + core_on_chip
        core = getattr(last_device, "core_on_chip", None)
        coords = getattr(last_device, "coords", None)
        if coords is not None and core is not None:
            x, y, z = coords
            return x + 1, y + 1, z + 1, core + 1
        # Non-TPU (or GPU lacking core_on_chip): hosts x local_devices.
        # jax.process_count() is the current name for the deprecated
        # jax.host_count() alias and yields the same value.
        return jax.process_count(), jax.local_device_count()

    def _get_coords_gpu_safe(device):
        """Return the mesh coordinates of ``device``."""
        core = getattr(device, "core_on_chip", None)
        coords = getattr(device, "coords", None)
        if coords is not None and core is not None:
            return (*coords, core)
        # Fallback that works on CPU/GPU
        return (device.process_index, device.id % jax.local_device_count())

    _t5x_part.bounds_from_last_device = _bounds_from_last_device_gpu_safe
    _t5x_part.get_coords = _get_coords_gpu_safe
    logging.info("Patched t5x.partitioning for GPU coords without core_on_chip.")


# Call the patch immediately at import time (before MagentaRT init)
_patch_t5x_for_gpu_coords()
59
+
60
  def create_documentation_interface():
61
  """Create a Gradio interface for documentation and transparency"""
62