fixing bugs
Browse files
utils.py
CHANGED
|
@@ -5,6 +5,14 @@ import numpy as np
|
|
| 5 |
from functools import partial
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def make_grid(patch_size: int | tuple[int, int]):
|
| 9 |
"""Gera grid de coordenadas com validação robusta"""
|
| 10 |
if isinstance(patch_size, int):
|
|
|
|
| 5 |
from functools import partial
|
| 6 |
|
| 7 |
|
| 8 |
+
def repeat_vmap(fun, in_axes=None):
|
| 9 |
+
if in_axes is None:
|
| 10 |
+
in_axes = [0]
|
| 11 |
+
for axes in in_axes:
|
| 12 |
+
fun = jax.vmap(fun, in_axes=axes)
|
| 13 |
+
return fun
|
| 14 |
+
|
| 15 |
+
|
| 16 |
def make_grid(patch_size: int | tuple[int, int]):
|
| 17 |
"""Gera grid de coordenadas com validação robusta"""
|
| 18 |
if isinstance(patch_size, int):
|