KoFace-AI / helper/util.py
Uploaded by JuyeopDang in commit 5ab5cab ("Upload 35 files").
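def extract(a, t, x_shape):
    """Pick a[t[i]] for each batch element, reshaped to (B, 1, ..., 1) to broadcast against x_shape.

    Assumes a is a 1-D tensor (e.g. a per-timestep coefficient table) and t is a 1-D int64 tensor of indices.
    """
    b, *_ = t.shape  # batch size
    out = a.gather(-1, t)  # out[i] = a[t[i]]
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))  # one trailing singleton dim per non-batch dim of x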
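The helper follows a pattern common in DDPM-style diffusion code: look up one schedule coefficient per sample, then reshape so it broadcasts over an image batch. The sketch below assumes that use. The linear beta schedule, `T = 10`, the tensor shapes, and the forward-noising line `x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise` are illustrative assumptions, not taken from KoFace-AI.

```python
import torch


def extract(a, t, x_shape):
    # Same helper as above: gather a[t[i]] per batch element, reshape to (B, 1, ..., 1).
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


# Hypothetical linear beta schedule with T = 10 steps (assumption, not from the repo).
T = 10
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# A batch of 4 images of shape 3x8x8, each at its own timestep.
x0 = torch.randn(4, 3, 8, 8)
t = torch.tensor([0, 3, 5, 9])  # int64 indices, as gather requires

coef = extract(alphas_cumprod, t, x0.shape)
print(coef.shape)  # torch.Size([4, 1, 1, 1])

# The (4, 1, 1, 1) coefficients broadcast over channels and pixels.
noise = torch.randn_like(x0)
x_t = coef.sqrt() * x0 + (1.0 - coef).sqrt() * noise
print(x_t.shape)  # torch.Size([4, 3, 8, 8])
```

`gather` requires `a` and `t` to live on the same device, so a schedule kept on the CPU must be moved (for example with `.to(t.device)`) before calling `extract` on GPU indices.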