File size: 129 Bytes
5ab5cab
 
 
 
1
2
3
4
def extract(a, t, x_shape):
    """Pick values of ``a`` at indices ``t`` and shape them to broadcast over ``x_shape``.

    The lookup is ``a.gather(-1, t)``, i.e. it indexes along the last dimension
    of ``a``. The gathered tensor is then reshaped to
    ``(batch, 1, ..., 1)``, where ``batch`` is the leading dimension of ``t``.
    There are ``len(x_shape) - 1`` trailing singleton axes, so the result
    broadcasts against a tensor whose shape is ``x_shape``.

    NOTE(review): presumably ``a`` is a 1-D table (e.g. a per-timestep
    schedule) and ``t`` is a 1-D integer index tensor of length ``batch``.
    The final reshape only succeeds if the gather yields exactly ``batch``
    elements. ``a`` and ``t`` are assumed to live on the same device. Confirm
    against callers.

    Args:
        a: Tensor to index along its last dimension.
        t: Integer index tensor; its first dimension is the batch size.
        x_shape: Shape (any sized sequence) of the tensor the result must
            broadcast against; only its length is used.

    Returns:
        The gathered values with shape ``(batch,) + (1,) * (len(x_shape) - 1)``.
    """
    batch_size, *_ = t.shape
    gathered = a.gather(-1, t)
    # One singleton axis per non-batch dimension of the target shape.
    broadcast_shape = (batch_size,) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(broadcast_shape)