Spaces:
Runtime error
Runtime error
| import torch | |
# Currently selected interpolation function. Starts as None and is assigned
# either `slerp` or `linear` by set_tensor_interpolation_method().
tensor_interpolation = None
def get_tensor_interpolation_method():
    """Return the module-wide interpolation function.

    The value is whatever was last stored in the module-level
    ``tensor_interpolation`` variable; it is ``None`` if nothing has been
    selected yet.
    """
    selected_method = tensor_interpolation
    return selected_method
def set_tensor_interpolation_method(is_slerp):
    """Select the module-wide interpolation function.

    Args:
        is_slerp: Truthy to select ``slerp``; falsy to select ``linear``.
    """
    global tensor_interpolation
    if is_slerp:
        tensor_interpolation = slerp
    else:
        tensor_interpolation = linear
def linear(v1, v2, t):
    """Linearly interpolate between ``v1`` and ``v2``.

    Args:
        v1: Value returned when ``t`` is 0.
        v2: Value returned when ``t`` is 1.
        t: Interpolation weight applied to ``v2``.

    Returns:
        ``(1.0 - t) * v1 + t * v2``.
    """
    start_weight = 1.0 - t
    return start_weight * v1 + t * v2
def slerp(
    v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    """Spherically interpolate between ``v0`` and ``v1``.

    The angle between the two inputs is measured on their normalised copies,
    but the interpolation weights are applied to the original (unnormalised)
    tensors.

    Args:
        v0: Tensor returned when ``t`` is 0.
        v1: Tensor returned when ``t`` is 1.
        t: Interpolation weight.
        DOT_THRESHOLD: When the absolute cosine of the angle between the
            inputs exceeds this value, plain linear interpolation is used
            instead of the spherical formula.

    Returns:
        The interpolated tensor. Falls back to linear interpolation when the
        inputs are nearly (anti-)parallel or when the angle between them is
        undefined (for example, one input has zero norm).
    """
    u0 = v0 / v0.norm()
    u1 = v1 / v1.norm()
    dot = (u0 * u1).sum()

    # A zero-length input turns the normalisation into 0/0, making ``dot`` NaN.
    # NaN fails the threshold comparison and would then poison acos() and the
    # whole result, so treat a non-finite cosine like the near-parallel case.
    if not torch.isfinite(dot) or dot.abs() > DOT_THRESHOLD:
        return (1.0 - t) * v0 + t * v1

    omega = dot.acos()
    sin_omega = omega.sin()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / sin_omega