File size: 1,114 Bytes
2bf5660 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import random
import math
import torch
def random_circle_init(min_r : float = 0.5, on_edge : bool = False):
    """Sample a random 2-D point in the annulus min_r <= r <= 1, or on the unit circle.

    The angle is drawn with ``random.uniform(0, 2 * pi)``. When ``on_edge`` is
    true the radius is exactly 1.0; otherwise it is drawn with
    ``random.uniform(min_r, 1.0)``. Because the radius itself (not the area)
    is uniform, points are denser toward the inner edge of the annulus.

    Args:
        min_r: Lower bound for the radius; ignored when ``on_edge`` is true.
        on_edge: If true, place the point on the unit circle.

    Returns:
        A tuple ``(x, y)`` of floats.
    """
    # The angle must be drawn before the radius to keep the RNG stream order.
    angle = random.uniform(0, 2 * math.pi)
    radius = 1.0 if on_edge else random.uniform(min_r, 1.0)
    return radius * math.cos(angle), radius * math.sin(angle)
def recursive_find_dtype(x):
    """Return the dtype of the first tensor found in a (possibly nested) list/tuple.

    Elements of ``x`` are scanned in order. Nested lists and tuples are
    searched depth-first before moving on to the next element. Any other
    element type is skipped.

    Args:
        x: An iterable (typically a list or tuple) whose elements may be
            ``torch.Tensor`` objects, nested lists/tuples, or anything else.

    Returns:
        The ``dtype`` of the first tensor encountered, or ``None`` if no
        tensor is found.
    """
    for item in x:
        if isinstance(item, torch.Tensor):
            return item.dtype
        # Recurse into tuples as well as lists, as the contract promises.
        # Other iterables (e.g. str) are intentionally not descended into.
        if isinstance(item, (list, tuple)):
            found = recursive_find_dtype(item)
            if found is not None:
                return found
    return None
def recursive_find_device(x):
    """Return the device of the first tensor found in a (possibly nested) list/tuple.

    Elements of ``x`` are scanned in order. Nested lists and tuples are
    searched depth-first before moving on to the next element. Any other
    element type is skipped.

    Args:
        x: An iterable (typically a list or tuple) whose elements may be
            ``torch.Tensor`` objects, nested lists/tuples, or anything else.

    Returns:
        The ``device`` of the first tensor encountered, or ``None`` if no
        tensor is found.
    """
    for item in x:
        if isinstance(item, torch.Tensor):
            return item.device
        # Recurse into tuples as well as lists, as the contract promises.
        # Other iterables (e.g. str) are intentionally not descended into.
        if isinstance(item, (list, tuple)):
            found = recursive_find_device(item)
            if found is not None:
                return found
    return None
|